In [1]:
from os import listdir

import time
import json
import copy

import torch

from torchvision.datasets import mnist, FashionMNIST, CIFAR10, CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.optim import SGD, Adam
from torch.nn import Module
from torch import nn
from torch.nn import CrossEntropyLoss
from torchvision.models.resnet import Bottleneck, ResNet
from torchvision import datasets, models, transforms

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from openood.evaluators import metrics

In [2]:
print(torch.version.cuda)

10.2


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [4]:
!pwd

/home/rdr2143/oodn-final-project/OpenOOD-nndl


### Supported Activation Functions

For activation functions, we are considering ReLU, Softplus, Swish. *Note that we may conduct experiments for a subset based on the compute resources available*

In [5]:
def get_activation_fn(activation):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'softplus':
        return nn.Softplus()
    elif activation == 'swish':
        return nn.Swish()
    return None

### LeNet

In [6]:
class LeNet(nn.Module):
    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.feature_size = 84
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=num_channel,
                      out_channels=6,
                      kernel_size=5,
                      stride=1,
                      padding=2), get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
             get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=120,
                      kernel_size=5,
                      stride=1), get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()

    def forward(self, x, return_feature=False, return_feature_list=False):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        logits_cls = self.fc(feature)
        feature_list = [feature1, feature2, feature3, feature]
        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        feature = feature.clip(max=threshold)
        logits_cls = self.fc(feature)

        return logits_cls

### ResNet50

In [7]:
def set_parameter_requires_grad(model):
    for name,param in model.named_parameters():
        if not (name.startswith('layer4') or name.startswith('fc')):
            param.requires_grad = False

def get_resnet_model(activation_function_type, n_classes, use_pretrained=True):
    resnet_model = models.resnet50(pretrained=use_pretrained)
    
    # if we use pretrained, then freeze the corresponding layers
    if use_pretrained:
        set_parameter_requires_grad(resnet_model, feature_extract)

    set_activation_function(resnet_model,activation_function_type)
    num_ftrs = resnet_model.fc.in_features
    resnet_model.fc = nn.Linear(num_ftrs, n_classes)
    resnet_model.to(device)
    return resnet_model

def set_activation_function(resnet_model, activation_function_type):
    resnet_model.relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[2].relu = get_activation_fn(activation_function_type)

    resnet_model.layer2[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[2].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[3].relu = get_activation_fn(activation_function_type)

    resnet_model.layer3[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[2].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[3].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[4].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[5].relu = get_activation_fn(activation_function_type)


    resnet_model.layer4[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer4[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer4[2].relu = get_activation_fn(activation_function_type)

    return resnet_model

In [8]:
def get_model(config):
    activation_function_type = config["activation_function_type"]
    network_type = config["network"]
    n_classes = config["n_classes"]

    if network_type == "lenet":
        model =  LeNet(num_classes=n_classes, num_channel=1, activation=activation_function_type)
    elif network_type == "resnet50":
        model = get_resnet_model(activation_function_type, n_classes, config['pretrained'])
    else:
        raise Exception("Currently we only support lenet or resnet50")

    return model

### Supported Post-Hoc OODN Processors

In [9]:
class OODPostprocessor():
    
    def inference(self, net: nn.Module, data_loader: DataLoader):
        pred_list, conf_list, label_list = [], [], []
        for idx, loaded_data in enumerate(data_loader):
            data, label = loaded_data[0], loaded_data[1]
            if idx % 50 == 0:
                print(f'Performing inference on batch: {idx}')
            pred, conf = self.postprocess(net, data.to(device))
            for idx in range(len(data)):
                pred_list.append(pred[idx].tolist())
                conf_list.append(conf[idx].tolist())
                label_list.append(label[idx].tolist())

        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        conf_list = np.array(conf_list)
        label_list = np.array(label_list, dtype=int)

        return pred_list, conf_list, label_list

In [10]:
class ODINPostprocessor(OODPostprocessor):
    def __init__(self, temperature, noise):
        super(OODPostprocessor)
        self.temperature = temperature
        self.noise = noise
        
    def postprocess(self, net: nn.Module, data):
        net.eval()
        data.requires_grad = True
        output = net(data)

        # Calculating the perturbation we need to add, that is,
        # the sign of gradient of cross entropy loss w.r.t. input
        criterion = nn.CrossEntropyLoss()

        labels = output.detach().argmax(axis=1)

        # Using temperature scaling
        output = output / self.temperature

        loss = criterion(output, labels)
        loss.backward()

        # Normalizing the gradient to binary in {0, 1}
        gradient = torch.ge(data.grad.detach(), 0)
        gradient = (gradient.float() - 0.5) * 2

        # Scaling values taken from original code       
        gradient[:, 0] = (gradient[:, 0]) / (63.0 / 255.0)
        if gradient.shape[1] == 3:
            gradient[:, 1] = (gradient[:, 1]) / (62.1 / 255.0)
            gradient[:, 2] = (gradient[:, 2]) / (66.7 / 255.0)

        # Adding small perturbations to images
        tempInputs = torch.add(data.detach(), gradient, alpha=-self.noise)
        output = net(tempInputs)
        output = output / self.temperature

        # Calculating the confidence after adding perturbations
        nnOutput = output.detach()
        nnOutput = nnOutput - nnOutput.max(dim=1, keepdims=True).values
        nnOutput = nnOutput.exp() / nnOutput.exp().sum(dim=1, keepdims=True)

        conf, pred = nnOutput.max(dim=1)

        return pred, conf

In [11]:
class MCDPostprocessor(OODPostprocessor):
    def __init__(self, samples: int = 30):
        super(OODPostprocessor)
        self.samples = samples  #: number :math:`N` of samples

    def postprocess(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        mode_switch = False
        if not model.training:
            mode_switch = True

            model.train()

            for mod in model.modules():
                # reset batch norm layers.
                # TODO: are there other layers?
                if isinstance(mod, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                    mod.train(False)

        results = None
        with torch.no_grad():
            for i in range(self.samples):
                output = model(x).softmax(dim=1)
                if results is None:
                    results = torch.zeros(size=output.shape).to(device)
                results += output
        results /= self.samples

        if mode_switch:
            model.eval()
        
        conf, pred = results.max(dim=1)

        return pred, conf

In [12]:
def get_postprocessor(postprocessor_type="odin"):
    if postprocessor_type == "odin":
        postprocessor = ODINPostprocessor(1000, 0.0014)
    elif postprocessor_type == "mcd":
        postprocessor = MCDPostprocessor(30)
    return postprocessor

### Supported Out of Distribution Detection Metrics

What metrics do we specifically care about here?

**FPR@95** measures the false positive rate (FPR) when the true positive rate (TPR) is
equal to 95%. Lower scores indicate better performance.

**AUROC** measures the area under the
Receiver Operating Characteristic (ROC) curve, which displays the relationship between TPR and
FPR. The area under the ROC curve can be interpreted as the probability that a positive ID example
will have a higher detection score than a negative OOD example.

**AUPR** measures the area under
the Precision-Recall (PR) curve. The PR curve is created by plotting precision versus recall. Similar
to AUROC, we consider ID samples as positive, so that the score corresponds to the AUPR-In metric
in some works

In [13]:
def calculate_oodn_metrics(model, postprocessor_type, id_test_loader, ood_test_loader, ood_name):
    postprocessor = get_postprocessor(postprocessor_type)
    id_pred, id_conf, id_gt = postprocessor.inference(
                model, id_test_loader)

    ood_pred, ood_conf, ood_gt = postprocessor.inference(
        model, ood_test_loader)

    ood_gt = -1 * np.ones_like(ood_gt)  # hard set to -1 as ood
    pred = np.concatenate([id_pred, ood_pred])
    conf = np.concatenate([id_conf, ood_conf])
    label = np.concatenate([id_gt, ood_gt])
    ood_metrics = metrics.compute_all_metrics(conf, label, pred)

    return print_and_get_formatted_metrics(ood_metrics, ood_name)

def print_and_get_formatted_metrics(metrics, dataset_name):
    [fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy] \
     = metrics

    write_content = {
        'dataset': dataset_name,
        'FPR@95': '{:.2f}'.format(100 * fpr),
        'AUROC': '{:.2f}'.format(100 * auroc),
        'AUPR_IN': '{:.2f}'.format(100 * aupr_in),
        'AUPR_OUT': '{:.2f}'.format(100 * aupr_out),
        'CCR_4': '{:.2f}'.format(100 * ccr_4),
        'CCR_3': '{:.2f}'.format(100 * ccr_3),
        'CCR_2': '{:.2f}'.format(100 * ccr_2),
        'CCR_1': '{:.2f}'.format(100 * ccr_1),
        'ACC': '{:.2f}'.format(100 * accuracy)
    }

    fieldnames = list(write_content.keys())

    # print ood metric results
    print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),
          end=' ',
          flush=True)
    print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
        100 * aupr_in, 100 * aupr_out),
          flush=True)
    print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),
          end=' ',
          flush=True)
    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)
    return write_content

def load_results_into_df(dir_path):
    res_files = [dir_path+each for each in listdir(dir_path)]
    all_results = []
    columns = ['optimizer_type', 'activation_function_type', 'postprocessor_type', 'trial', 'AUROC']
    for fp in res_files:
        f = open(fp)
        data = json.load(f)
        for trial, results in data.items():
            all_results.append([
                    results['optimizer_type'],
                    results['activation_function_type'],
                    results['postprocessor_type'],
                    trial,
                    float(results['AUROC'])
                ])
    df = pd.DataFrame(all_results, columns=columns)
    return df

In [21]:
def get_optimizer(model, config):
    params = model.parameters()
    lr = config['lr']
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer_type = config['optimizer_type']

    print(f'Getting optimizer for type: {optimizer_type}...')
    if optimizer_type == 'SGD':
        return SGD(params,
              lr=lr,
              momentum=momentum,
              weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        return Adam(params,
                    lr=lr,
                    weight_decay=weight_decay)
    else:
        raise Exception("Invalid optimizer_type provided, only SGD and Adam are supported currently")

def get_wilds_loader(dataset, split, batch_size):
    d = dataset.get_subset(
        split,
        # frac=0.1,
        transform=transforms.Compose(
            [transforms.Resize((448, 448)), transforms.ToTensor()]
        ),
    )
    # Prepare the standard data loader
    return get_train_loader("standard", d, batch_size=batch_size, num_workers=4)

def get_data_loaders(config):
    data_loaders = {}
    dataset_name = config["dataset_name"]
    dataset_type = config["dataset_type"]
    batch_size = config['batch_size']

    wilds_id_test_split = "id_val" if dataset_name == "camelyon17" else "id_test"
    if dataset_type == "wilds":
        # wilds dataset
        dataset = get_dataset(dataset=dataset_name, download=True)
        data_loaders["train"] = get_wilds_loader(dataset, "train", batch_size)
        data_loaders["ood_test"] = get_wilds_loader(dataset, "test", batch_size)
        data_loaders["id_test"] = get_wilds_loader(dataset, wilds_id_test_split, batch_size)
        return
    elif dataset_name == "cifar":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = CIFAR10(root='data', download=True, train=True, transform=transform_train)
        test_dataset = CIFAR10(root='data', download=True, train=False, transform=transform_test)
        ood_test_dataset = CIFAR100(root='data', download=True, train=False, transform=transform_test)
    elif dataset_name == "mnist":
        # mnist dataset
        train_dataset = mnist.MNIST(root='data', download=True, train=True, transform=ToTensor())
        test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())
        ood_test_dataset = mnist.FashionMNIST(root='data', download=True,train=False,transform=ToTensor())

    data_loaders["train"] = DataLoader(train_dataset, batch_size=batch_size)
    data_loaders["id_test"] = DataLoader(test_dataset, batch_size=batch_size)
    data_loaders["ood_test"] = DataLoader(ood_test_dataset, batch_size=batch_size)

    return data_loaders

In [22]:
def train_resnet_model_given_opti_activation_fn(config):
    # get the train loader
    train_loader = config["data_loaders"]["train"]

    # get the resnet model with the replaced activation functions
    model = get_model(config)
    model.to(device)

    # get the optimizer
    sgd = get_optimizer(model, config)

    loss_fn = CrossEntropyLoss()
    for current_epoch in range(config['epochs']):
        tic=time.time()
        per_batch_time = time.time()
        model.train()
        print('Training epoch: {}'.format(current_epoch))
        for idx, (loader_data) in enumerate(train_loader):
            train_x, train_label = loader_data[0].to(device), loader_data[1].to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 100 == 0:
                print('idx: {}, loss: {} time take: {}'.format(idx, loss.sum().item(), time.time() - per_batch_time))
                per_batch_time = time.time()
            loss.backward()
            sgd.step()
        print(f"epoch {current_epoch} time taken: {time.time()-tic}s")
    torch.save(model, config['model_name'])

    return model

def run_full_oodn_pipeline(config):
    metrics = {}
    for i in range(config["trials"]):
        model_name = f"models/{config['dataset_name']}_{config['network']}_{config['postprocessor_type']}_{config['activation_function_type']}_{config['optimizer_type']}_{i}.pkl"
        print(f'Running model: {model_name}...')
        config['model_name'] = model_name
        # train model
        model = train_resnet_model_given_opti_activation_fn(config)
        # calculate oodn metrics
        metrics[i] = calculate_oodn_metrics(model,
                               config['postprocessor_type'],
                               config["data_loaders"]["id_test"],
                               config["data_loaders"]["ood_test"],
                               config["dataset_name"])
        metrics[i]['optimizer_type'] = config['optimizer_type']
        metrics[i]['activation_function_type'] = config['activation_function_type']
        metrics[i]['postprocessor_type'] = config['postprocessor_type']

    experiment_name = f"{config['results_dir']}/{config['dataset_name']}_{config['network']}_{config['postprocessor_type']}_{config['activation_function_type']}_{config['optimizer_type']}.json"
    with open(experiment_name, 'w') as fp:
        json.dump(metrics, fp)
    return metrics

### Study 2: Resnet, CIFAR-10 (ID), CIFAR-100 (OOD)

#### Study 2(a): Adam + ReLU + Odin

In [23]:
config_cifar_adam_relu_odin = {
    "batch_size": 32,
    "n_classes": 10,
    "dataset_name": "cifar",
    "epochs": 100,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
    "activation_function_type": "relu",
    "network": "resnet50",
    "postprocessor_type": "odin",
    "trials": 3,
    "dataset_type": "cifar",
    "results_dir": "cifar10-study",
    "pretrained": False
} 
config_cifar_adam_relu_odin["data_loaders"] = get_data_loaders(config_cifar_adam_relu_odin)
run_full_oodn_pipeline(config_cifar_adam_relu_odin)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Running model: models/cifar_resnet50_odin_relu_Adam_0.pkl...
Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.642430305480957 time take: 0.024790525436401367
idx: 100, loss: 4.273593902587891 time take: 4.180047273635864
idx: 200, loss: 2.217350482940674 time take: 4.183010101318359
idx: 300, loss: 2.5634069442749023 time take: 4.17393159866333
idx: 400, loss: 1.8470849990844727 time take: 4.197870492935181
idx: 500, loss: 2.0578439235687256 time take: 4.173777341842651
idx: 600, loss: 2.336688756942749 time take: 4.176985025405884
idx: 700, loss: 2.0894312858581543 time take: 4.166748762130737
idx: 800, loss: 2.3092563152313232 time take: 4.266041040420532
idx: 900, loss: 1.912632942199707 time take: 4.179208040237427
idx: 1000, loss: 1.8452318906784058 time take: 4.132830619812012
idx: 1100, loss: 1.9045339822769165 time take: 4.162060737609863
idx: 1

idx: 700, loss: 1.768478274345398 time take: 4.101844549179077
idx: 800, loss: 1.6334450244903564 time take: 4.12708854675293
idx: 900, loss: 1.3116384744644165 time take: 4.131580591201782
idx: 1000, loss: 1.5481717586517334 time take: 4.133164405822754
idx: 1100, loss: 1.3579862117767334 time take: 4.1207897663116455
idx: 1200, loss: 1.616591215133667 time take: 4.107295274734497
idx: 1300, loss: 1.7291995286941528 time take: 4.1147141456604
idx: 1400, loss: 1.6823824644088745 time take: 4.239807367324829
idx: 1500, loss: 1.8291126489639282 time take: 4.107453107833862
epoch 7 time taken: 64.45949411392212s
Training epoch: 8
idx: 0, loss: 1.6590707302093506 time take: 0.023249149322509766
idx: 100, loss: 1.5848557949066162 time take: 4.166579484939575
idx: 200, loss: 1.4233311414718628 time take: 4.122305870056152
idx: 300, loss: 1.7282540798187256 time take: 4.111442565917969
idx: 400, loss: 1.3581236600875854 time take: 4.1342902183532715
idx: 500, loss: 1.4616546630859375 time tak

idx: 100, loss: 1.3700780868530273 time take: 4.119949579238892
idx: 200, loss: 1.504178524017334 time take: 4.10373854637146
idx: 300, loss: 1.520316481590271 time take: 4.099982500076294
idx: 400, loss: 1.1981079578399658 time take: 4.116159200668335
idx: 500, loss: 1.6862860918045044 time take: 4.248613357543945
idx: 600, loss: 1.7352010011672974 time take: 4.100274324417114
idx: 700, loss: 1.673231601715088 time take: 4.118092775344849
idx: 800, loss: 1.5503705739974976 time take: 4.0982019901275635
idx: 900, loss: 1.2449883222579956 time take: 4.1225714683532715
idx: 1000, loss: 1.6015021800994873 time take: 4.154203653335571
idx: 1100, loss: 1.3361696004867554 time take: 4.129671812057495
idx: 1200, loss: 1.4449882507324219 time take: 4.1100709438323975
idx: 1300, loss: 1.2082690000534058 time take: 4.126873731613159
idx: 1400, loss: 1.4492584466934204 time take: 4.1291680335998535
idx: 1500, loss: 1.9609289169311523 time take: 4.14717435836792
epoch 15 time taken: 64.51310515403

idx: 1100, loss: 1.2494267225265503 time take: 4.1029298305511475
idx: 1200, loss: 1.3682093620300293 time take: 4.25761079788208
idx: 1300, loss: 1.3230633735656738 time take: 4.1253862380981445
idx: 1400, loss: 1.3206459283828735 time take: 4.112245321273804
idx: 1500, loss: 1.7642399072647095 time take: 4.136320352554321
epoch 22 time taken: 64.5363404750824s
Training epoch: 23
idx: 0, loss: 1.8684697151184082 time take: 0.022437095642089844
idx: 100, loss: 1.2242227792739868 time take: 4.119815826416016
idx: 200, loss: 1.3838603496551514 time take: 4.147754192352295
idx: 300, loss: 1.9519283771514893 time take: 4.122519016265869
idx: 400, loss: 1.1979789733886719 time take: 4.1149842739105225
idx: 500, loss: 1.4748904705047607 time take: 4.155505418777466
idx: 600, loss: 1.507365107536316 time take: 4.148080110549927
idx: 700, loss: 1.730521559715271 time take: 4.165325880050659
idx: 800, loss: 1.337943434715271 time take: 4.131303071975708
idx: 900, loss: 1.0752066373825073 time t

idx: 400, loss: 1.1737499237060547 time take: 4.093587875366211
idx: 500, loss: 1.3312734365463257 time take: 4.0967018604278564
idx: 600, loss: 1.613966941833496 time take: 4.090072870254517
idx: 700, loss: 1.4541596174240112 time take: 4.118129014968872
idx: 800, loss: 1.4547381401062012 time take: 4.114704608917236
idx: 900, loss: 1.186826467514038 time take: 4.128931045532227
idx: 1000, loss: 1.345072627067566 time take: 4.132929801940918
idx: 1100, loss: 1.1586062908172607 time take: 4.149299383163452
idx: 1200, loss: 1.3948030471801758 time take: 4.1006176471710205
idx: 1300, loss: 1.3864014148712158 time take: 4.121099233627319
idx: 1400, loss: 1.5418606996536255 time take: 4.119306564331055
idx: 1500, loss: 1.608499526977539 time take: 4.13366174697876
epoch 30 time taken: 64.4759612083435s
Training epoch: 31
idx: 0, loss: 1.5346505641937256 time take: 0.02272200584411621
idx: 100, loss: 1.3963303565979004 time take: 4.120968818664551
idx: 200, loss: 1.3838149309158325 time tak

idx: 1400, loss: 1.264801025390625 time take: 4.132538318634033
idx: 1500, loss: 1.6242222785949707 time take: 4.10059928894043
epoch 37 time taken: 64.57576537132263s
Training epoch: 38
idx: 0, loss: 1.5455503463745117 time take: 0.022011995315551758
idx: 100, loss: 1.4003320932388306 time take: 4.099682807922363
idx: 200, loss: 1.3846673965454102 time take: 4.105387449264526
idx: 300, loss: 1.5341720581054688 time take: 4.149273633956909
idx: 400, loss: 1.233416199684143 time take: 4.136183261871338
idx: 500, loss: 1.419564127922058 time take: 4.121412754058838
idx: 600, loss: 1.5921952724456787 time take: 4.145097255706787
idx: 700, loss: 1.677337646484375 time take: 4.132298707962036
idx: 800, loss: 1.3544764518737793 time take: 4.127435922622681
idx: 900, loss: 1.1305724382400513 time take: 4.240591049194336
idx: 1000, loss: 1.498399257659912 time take: 4.156368255615234
idx: 1100, loss: 1.1651551723480225 time take: 4.138692855834961
idx: 1200, loss: 1.5127055644989014 time take:

idx: 700, loss: 1.674086570739746 time take: 4.0899200439453125
idx: 800, loss: 1.5746772289276123 time take: 4.094297170639038
idx: 900, loss: 1.114585518836975 time take: 4.135002374649048
idx: 1000, loss: 1.7901983261108398 time take: 4.107908487319946
idx: 1100, loss: 1.136747121810913 time take: 4.124406576156616
idx: 1200, loss: 1.4279919862747192 time take: 4.13542366027832
idx: 1300, loss: 1.4166609048843384 time take: 4.114991188049316
idx: 1400, loss: 1.5053040981292725 time take: 4.113423585891724
idx: 1500, loss: 1.8221560716629028 time take: 4.125681161880493
epoch 45 time taken: 64.50027394294739s
Training epoch: 46
idx: 0, loss: 1.4097082614898682 time take: 0.022697925567626953
idx: 100, loss: 1.49616539478302 time take: 4.135374546051025
idx: 200, loss: 1.2566032409667969 time take: 4.085765361785889
idx: 300, loss: 1.5999573469161987 time take: 4.087802886962891
idx: 400, loss: 1.2826142311096191 time take: 4.115556001663208
idx: 500, loss: 1.5308506488800049 time tak

idx: 100, loss: 1.365668535232544 time take: 4.073619365692139
idx: 200, loss: 1.5495373010635376 time take: 4.081064701080322
idx: 300, loss: 1.53732430934906 time take: 4.088383913040161
idx: 400, loss: 0.9994457364082336 time take: 4.152534246444702
idx: 500, loss: 1.5843393802642822 time take: 4.174293279647827
idx: 600, loss: 1.6086411476135254 time take: 4.109976530075073
idx: 700, loss: 1.6312544345855713 time take: 4.224348068237305
idx: 800, loss: 1.4486466646194458 time take: 4.090992212295532
idx: 900, loss: 1.178405523300171 time take: 4.114879131317139
idx: 1000, loss: 1.5972552299499512 time take: 4.103013515472412
idx: 1100, loss: 1.2611333131790161 time take: 4.078481435775757
idx: 1200, loss: 1.4301141500473022 time take: 4.109015703201294
idx: 1300, loss: 1.354706883430481 time take: 4.084815502166748
idx: 1400, loss: 1.682235836982727 time take: 4.077078342437744
idx: 1500, loss: 2.030822515487671 time take: 4.071633577346802
epoch 53 time taken: 64.20370650291443s
T

idx: 1100, loss: 1.454136610031128 time take: 4.096735715866089
idx: 1200, loss: 1.443176507949829 time take: 4.08686375617981
idx: 1300, loss: 1.245572805404663 time take: 4.091705083847046
idx: 1400, loss: 1.3735440969467163 time take: 4.140681505203247
idx: 1500, loss: 1.7583950757980347 time take: 4.211449384689331
epoch 60 time taken: 64.23724961280823s
Training epoch: 61
idx: 0, loss: 1.493891954421997 time take: 0.02232217788696289
idx: 100, loss: 1.5603883266448975 time take: 4.105429172515869
idx: 200, loss: 1.3170554637908936 time take: 4.100595712661743
idx: 300, loss: 1.5861196517944336 time take: 4.1045801639556885
idx: 400, loss: 1.2918680906295776 time take: 4.085706949234009
idx: 500, loss: 1.4713033437728882 time take: 4.080963850021362
idx: 600, loss: 1.4909604787826538 time take: 4.082334518432617
idx: 700, loss: 1.6435216665267944 time take: 4.115860939025879
idx: 800, loss: 1.5407732725143433 time take: 4.097870111465454
idx: 900, loss: 1.1157302856445312 time take

idx: 400, loss: 1.2313119173049927 time take: 4.109345197677612
idx: 500, loss: 1.3832876682281494 time take: 4.0728442668914795
idx: 600, loss: 1.618671178817749 time take: 4.198811292648315
idx: 700, loss: 1.4910825490951538 time take: 4.130047798156738
idx: 800, loss: 1.4914301633834839 time take: 4.084486961364746
idx: 900, loss: 1.1473805904388428 time take: 4.085716724395752
idx: 1000, loss: 1.5287104845046997 time take: 4.098724126815796
idx: 1100, loss: 1.3201334476470947 time take: 4.1087117195129395
idx: 1200, loss: 1.5898206233978271 time take: 4.1085779666900635
idx: 1300, loss: 1.3642315864562988 time take: 4.094602346420288
idx: 1400, loss: 1.4048980474472046 time take: 4.081398963928223
idx: 1500, loss: 1.8801522254943848 time take: 4.089293718338013
epoch 68 time taken: 64.12809896469116s
Training epoch: 69
idx: 0, loss: 1.6162769794464111 time take: 0.022735595703125
idx: 100, loss: 1.3302042484283447 time take: 4.077907562255859
idx: 200, loss: 1.4842933416366577 time

idx: 1400, loss: 1.451264500617981 time take: 4.207407236099243
idx: 1500, loss: 1.539455533027649 time take: 4.090462923049927
epoch 75 time taken: 64.0879819393158s
Training epoch: 76
idx: 0, loss: 1.7609758377075195 time take: 0.025115013122558594
idx: 100, loss: 1.4350836277008057 time take: 4.10488486289978
idx: 200, loss: 1.3297429084777832 time take: 4.092406272888184
idx: 300, loss: 1.5372908115386963 time take: 4.112916707992554
idx: 400, loss: 1.1935898065567017 time take: 4.10565185546875
idx: 500, loss: 1.5871648788452148 time take: 4.105617046356201
idx: 600, loss: 1.634217381477356 time take: 4.090444564819336
idx: 700, loss: 1.47573983669281 time take: 4.1119842529296875
idx: 800, loss: 1.3648087978363037 time take: 4.079435586929321
idx: 900, loss: 1.187167763710022 time take: 4.093709945678711
idx: 1000, loss: 1.471205234527588 time take: 4.089365720748901
idx: 1100, loss: 1.2477836608886719 time take: 4.108833074569702
idx: 1200, loss: 1.2380672693252563 time take: 4.

idx: 700, loss: 1.5610491037368774 time take: 4.103128671646118
idx: 800, loss: 1.3183410167694092 time take: 4.086705446243286
idx: 900, loss: 1.029597282409668 time take: 4.114576101303101
idx: 1000, loss: 1.7228602170944214 time take: 4.086587190628052
idx: 1100, loss: 1.31631600856781 time take: 4.094069004058838
idx: 1200, loss: 1.4130719900131226 time take: 4.087055206298828
idx: 1300, loss: 1.3407881259918213 time take: 4.097006559371948
idx: 1400, loss: 1.3746768236160278 time take: 4.11085057258606
idx: 1500, loss: 1.848964810371399 time take: 4.10393762588501
epoch 83 time taken: 64.10367059707642s
Training epoch: 84
idx: 0, loss: 1.6916143894195557 time take: 0.022627592086791992
idx: 100, loss: 1.382141351699829 time take: 4.091655969619751
idx: 200, loss: 1.2335231304168701 time take: 4.083616495132446
idx: 300, loss: 1.6961215734481812 time take: 4.121000528335571
idx: 400, loss: 1.2629876136779785 time take: 4.11745285987854
idx: 500, loss: 1.5085761547088623 time take: 

idx: 100, loss: 1.2997312545776367 time take: 4.101655960083008
idx: 200, loss: 1.3222274780273438 time take: 4.094825744628906
idx: 300, loss: 1.4142179489135742 time take: 4.1369524002075195
idx: 400, loss: 1.172485589981079 time take: 4.088444232940674
idx: 500, loss: 1.5088067054748535 time take: 4.088700294494629
idx: 600, loss: 1.643985390663147 time take: 4.086732864379883
idx: 700, loss: 1.3988062143325806 time take: 4.126394033432007
idx: 800, loss: 1.2944550514221191 time take: 4.094105958938599
idx: 900, loss: 0.9424707889556885 time take: 4.085766077041626
idx: 1000, loss: 1.431445837020874 time take: 4.11061954498291
idx: 1100, loss: 1.2894303798675537 time take: 4.107223987579346
idx: 1200, loss: 1.5005284547805786 time take: 4.212416172027588
idx: 1300, loss: 1.5087810754776 time take: 4.107748031616211
idx: 1400, loss: 1.4235495328903198 time take: 4.106525659561157
idx: 1500, loss: 1.7251948118209839 time take: 4.118012428283691
epoch 91 time taken: 64.23635125160217s


idx: 1100, loss: 1.2756342887878418 time take: 4.1065075397491455
idx: 1200, loss: 1.3433047533035278 time take: 4.144544363021851
idx: 1300, loss: 1.5905067920684814 time take: 4.075014114379883
idx: 1400, loss: 1.50370454788208 time take: 4.089400291442871
idx: 1500, loss: 1.808735728263855 time take: 4.096208572387695
epoch 98 time taken: 64.1711540222168s
Training epoch: 99
idx: 0, loss: 1.3642970323562622 time take: 0.022623300552368164
idx: 100, loss: 1.429739236831665 time take: 4.117614030838013
idx: 200, loss: 1.295657753944397 time take: 4.109603643417358
idx: 300, loss: 1.453009843826294 time take: 4.089138031005859
idx: 400, loss: 1.2382599115371704 time take: 4.195043087005615
idx: 500, loss: 1.4387611150741577 time take: 4.0942394733428955
idx: 600, loss: 1.5129754543304443 time take: 4.081867933273315
idx: 700, loss: 1.7377322912216187 time take: 4.08667516708374
idx: 800, loss: 1.428065538406372 time take: 4.083519697189331
idx: 900, loss: 1.0040082931518555 time take: 

idx: 900, loss: 1.1186925172805786 time take: 4.0735039710998535
idx: 1000, loss: 1.3509950637817383 time take: 4.0886571407318115
idx: 1100, loss: 1.3787336349487305 time take: 4.0712525844573975
idx: 1200, loss: 1.4466617107391357 time take: 4.08640193939209
idx: 1300, loss: 1.5809483528137207 time take: 4.136145830154419
idx: 1400, loss: 1.6993741989135742 time take: 4.148746967315674
idx: 1500, loss: 1.7564647197723389 time take: 4.131420373916626
epoch 5 time taken: 64.12180352210999s
Training epoch: 6
idx: 0, loss: 1.61684250831604 time take: 0.022773265838623047
idx: 100, loss: 1.3224231004714966 time take: 4.140617847442627
idx: 200, loss: 1.2860631942749023 time take: 4.125016927719116
idx: 300, loss: 1.5287195444107056 time take: 4.130423307418823
idx: 400, loss: 1.308443546295166 time take: 4.196406126022339
idx: 500, loss: 1.571101188659668 time take: 4.107629060745239
idx: 600, loss: 1.6865750551223755 time take: 4.133272171020508
idx: 700, loss: 1.6361967325210571 time ta

idx: 200, loss: 1.4566234350204468 time take: 4.069654703140259
idx: 300, loss: 1.3638063669204712 time take: 4.106574773788452
idx: 400, loss: 1.2106808423995972 time take: 4.080965280532837
idx: 500, loss: 1.4836872816085815 time take: 4.086673259735107
idx: 600, loss: 1.46699857711792 time take: 4.081318378448486
idx: 700, loss: 1.559906005859375 time take: 4.117892026901245
idx: 800, loss: 1.6295803785324097 time take: 4.0767810344696045
idx: 900, loss: 1.1542876958847046 time take: 4.065668344497681
idx: 1000, loss: 1.3103224039077759 time take: 4.101616382598877
idx: 1100, loss: 1.273632287979126 time take: 4.228163957595825
idx: 1200, loss: 1.3820327520370483 time take: 4.075563192367554
idx: 1300, loss: 1.1619774103164673 time take: 4.072609186172485
idx: 1400, loss: 1.3761345148086548 time take: 4.064053058624268
idx: 1500, loss: 1.7438316345214844 time take: 4.067106008529663
epoch 13 time taken: 63.939327001571655s
Training epoch: 14
idx: 0, loss: 1.6071488857269287 time tak

idx: 1200, loss: 1.506663203239441 time take: 4.097923755645752
idx: 1300, loss: 1.2786933183670044 time take: 4.081552267074585
idx: 1400, loss: 1.1699578762054443 time take: 4.086987257003784
idx: 1500, loss: 1.8167839050292969 time take: 4.124587297439575
epoch 20 time taken: 64.19526839256287s
Training epoch: 21
idx: 0, loss: 1.507240653038025 time take: 0.02247023582458496
idx: 100, loss: 1.4071859121322632 time take: 4.089818716049194
idx: 200, loss: 1.2519948482513428 time take: 4.089057445526123
idx: 300, loss: 1.6047661304473877 time take: 4.249798059463501
idx: 400, loss: 1.0250649452209473 time take: 4.0963897705078125
idx: 500, loss: 1.4554550647735596 time take: 4.113013982772827
idx: 600, loss: 1.6756591796875 time take: 4.06885838508606
idx: 700, loss: 1.6120201349258423 time take: 4.118013143539429
idx: 800, loss: 1.58279550075531 time take: 4.087066650390625
idx: 900, loss: 1.1212069988250732 time take: 4.07938289642334
idx: 1000, loss: 1.4260579347610474 time take: 4.

idx: 500, loss: 1.3523138761520386 time take: 4.071673631668091
idx: 600, loss: 1.72797691822052 time take: 4.070549011230469
idx: 700, loss: 1.5916982889175415 time take: 4.074698209762573
idx: 800, loss: 1.5675177574157715 time take: 4.078209400177002
idx: 900, loss: 1.0572980642318726 time take: 4.068892955780029
idx: 1000, loss: 1.461746096611023 time take: 4.138556003570557
idx: 1100, loss: 1.3351153135299683 time take: 4.2597949504852295
idx: 1200, loss: 1.6047862768173218 time take: 4.075909614562988
idx: 1300, loss: 1.3732532262802124 time take: 4.1219823360443115
idx: 1400, loss: 1.2981023788452148 time take: 4.090404033660889
idx: 1500, loss: 1.7502621412277222 time take: 4.108118295669556
epoch 28 time taken: 64.04892301559448s
Training epoch: 29
idx: 0, loss: 1.3943321704864502 time take: 0.022820234298706055
idx: 100, loss: 1.3444448709487915 time take: 4.095958232879639
idx: 200, loss: 1.5054335594177246 time take: 4.082967281341553
idx: 300, loss: 1.5901062488555908 time

idx: 1500, loss: 1.6163126230239868 time take: 4.10875940322876
epoch 35 time taken: 64.24778914451599s
Training epoch: 36
idx: 0, loss: 1.4603707790374756 time take: 0.02258610725402832
idx: 100, loss: 1.4075108766555786 time take: 4.080106258392334
idx: 200, loss: 1.4231206178665161 time take: 4.103612661361694
idx: 300, loss: 1.5540993213653564 time take: 4.265920400619507
idx: 400, loss: 1.1784193515777588 time take: 4.086280345916748
idx: 500, loss: 1.6638290882110596 time take: 4.073592662811279
idx: 600, loss: 1.6103333234786987 time take: 4.110100030899048
idx: 700, loss: 1.559643030166626 time take: 4.108963251113892
idx: 800, loss: 1.6567903757095337 time take: 4.101899147033691
idx: 900, loss: 1.0616726875305176 time take: 4.0833024978637695
idx: 1000, loss: 1.5493252277374268 time take: 4.0818963050842285
idx: 1100, loss: 1.3916860818862915 time take: 4.100948810577393
idx: 1200, loss: 1.3283442258834839 time take: 4.084386825561523
idx: 1300, loss: 1.2708024978637695 time 

idx: 800, loss: 1.3406676054000854 time take: 4.102266311645508
idx: 900, loss: 1.0348172187805176 time take: 4.116308689117432
idx: 1000, loss: 1.4226934909820557 time take: 4.247059345245361
idx: 1100, loss: 1.3926784992218018 time take: 4.113920211791992
idx: 1200, loss: 1.353508472442627 time take: 4.099497318267822
idx: 1300, loss: 1.2050338983535767 time take: 4.084714889526367
idx: 1400, loss: 1.2760529518127441 time take: 4.099152565002441
idx: 1500, loss: 1.7293211221694946 time take: 4.103497266769409
epoch 43 time taken: 64.19590759277344s
Training epoch: 44
idx: 0, loss: 1.4770230054855347 time take: 0.02198648452758789
idx: 100, loss: 1.3980685472488403 time take: 4.1022961139678955
idx: 200, loss: 1.3646427392959595 time take: 4.127734661102295
idx: 300, loss: 1.7034567594528198 time take: 4.104575157165527
idx: 400, loss: 1.2673006057739258 time take: 4.117753744125366
idx: 500, loss: 1.3976683616638184 time take: 4.0925843715667725
idx: 600, loss: 1.6156225204467773 tim

idx: 100, loss: 1.348804235458374 time take: 4.224057912826538
idx: 200, loss: 1.2694077491760254 time take: 4.111453533172607
idx: 300, loss: 1.519669771194458 time take: 4.08111834526062
idx: 400, loss: 1.2865571975708008 time take: 4.085457801818848
idx: 500, loss: 1.5006365776062012 time take: 4.126481294631958
idx: 600, loss: 1.599854588508606 time take: 4.115265369415283
idx: 700, loss: 1.5595163106918335 time take: 4.097736358642578
idx: 800, loss: 1.4622248411178589 time take: 4.109259366989136
idx: 900, loss: 1.0242921113967896 time take: 4.0975236892700195
idx: 1000, loss: 1.512865424156189 time take: 4.105043649673462
idx: 1100, loss: 1.3167635202407837 time take: 4.086653470993042
idx: 1200, loss: 1.5192910432815552 time take: 4.090983629226685
idx: 1300, loss: 1.3209998607635498 time take: 4.103598356246948
idx: 1400, loss: 1.3283542394638062 time take: 4.091430902481079
idx: 1500, loss: 1.7327183485031128 time take: 4.084613800048828
epoch 51 time taken: 64.18655133247375

idx: 1100, loss: 1.285668134689331 time take: 4.080373048782349
idx: 1200, loss: 1.2304109334945679 time take: 4.073659896850586
idx: 1300, loss: 1.3400615453720093 time take: 4.076532602310181
idx: 1400, loss: 1.4178745746612549 time take: 4.123785734176636
idx: 1500, loss: 1.811737298965454 time take: 4.0876076221466064
epoch 58 time taken: 64.10609936714172s
Training epoch: 59
idx: 0, loss: 1.5040994882583618 time take: 0.023023128509521484
idx: 100, loss: 1.5386472940444946 time take: 4.126091957092285
idx: 200, loss: 1.353058099746704 time take: 4.12346887588501
idx: 300, loss: 1.3921873569488525 time take: 4.102466344833374
idx: 400, loss: 1.1828253269195557 time take: 4.090531349182129
idx: 500, loss: 1.4447622299194336 time take: 4.081971883773804
idx: 600, loss: 1.6849924325942993 time take: 4.076770544052124
idx: 700, loss: 1.5619101524353027 time take: 4.075355768203735
idx: 800, loss: 1.5233874320983887 time take: 4.248363256454468
idx: 900, loss: 1.2297881841659546 time ta

idx: 400, loss: 1.142112135887146 time take: 4.080436706542969
idx: 500, loss: 1.3739863634109497 time take: 4.068546295166016
idx: 600, loss: 1.5887454748153687 time take: 4.095673561096191
idx: 700, loss: 1.6570452451705933 time take: 4.0823047161102295
idx: 800, loss: 1.3387119770050049 time take: 4.1341025829315186
idx: 900, loss: 1.0573644638061523 time take: 4.09366774559021
idx: 1000, loss: 1.4427424669265747 time take: 4.096602916717529
idx: 1100, loss: 1.412557601928711 time take: 4.097280502319336
idx: 1200, loss: 1.5045065879821777 time take: 4.104437589645386
idx: 1300, loss: 1.1836603879928589 time take: 4.120256423950195
idx: 1400, loss: 1.3717710971832275 time take: 4.113883018493652
idx: 1500, loss: 1.8860564231872559 time take: 4.0851662158966064
epoch 66 time taken: 64.27430367469788s
Training epoch: 67
idx: 0, loss: 1.5456141233444214 time take: 0.02241230010986328
idx: 100, loss: 1.4234139919281006 time take: 4.103550434112549
idx: 200, loss: 1.3140400648117065 time

idx: 1400, loss: 1.1820905208587646 time take: 4.079222917556763
idx: 1500, loss: 1.83063542842865 time take: 4.086563587188721
epoch 73 time taken: 64.09207248687744s
Training epoch: 74
idx: 0, loss: 1.4915645122528076 time take: 0.023941516876220703
idx: 100, loss: 1.3732552528381348 time take: 4.092215538024902
idx: 200, loss: 1.3228590488433838 time take: 4.133550643920898
idx: 300, loss: 1.62322199344635 time take: 4.08809232711792
idx: 400, loss: 1.149620532989502 time take: 4.098516464233398
idx: 500, loss: 1.2733745574951172 time take: 4.101291656494141
idx: 600, loss: 1.5763113498687744 time take: 4.103321313858032
idx: 700, loss: 1.5624125003814697 time take: 4.1219682693481445
idx: 800, loss: 1.4155523777008057 time take: 4.227167367935181
idx: 900, loss: 1.1384028196334839 time take: 4.107970714569092
idx: 1000, loss: 1.371417760848999 time take: 4.087357521057129
idx: 1100, loss: 1.403259515762329 time take: 4.0770158767700195
idx: 1200, loss: 1.2967275381088257 time take:

idx: 700, loss: 1.3718923330307007 time take: 4.068787574768066
idx: 800, loss: 1.5072427988052368 time take: 4.092418909072876
idx: 900, loss: 1.2135858535766602 time take: 4.099251985549927
idx: 1000, loss: 1.471216082572937 time take: 4.072847366333008
idx: 1100, loss: 1.4193655252456665 time take: 4.154926300048828
idx: 1200, loss: 1.3474615812301636 time take: 4.084748029708862
idx: 1300, loss: 1.3368028402328491 time take: 4.111368179321289
idx: 1400, loss: 1.371289849281311 time take: 4.0993993282318115
idx: 1500, loss: 1.938759446144104 time take: 4.223873853683472
epoch 81 time taken: 64.26583313941956s
Training epoch: 82
idx: 0, loss: 1.3362051248550415 time take: 0.02266860008239746
idx: 100, loss: 1.5026050806045532 time take: 4.080666780471802
idx: 200, loss: 1.3403866291046143 time take: 4.0798234939575195
idx: 300, loss: 1.5155256986618042 time take: 4.09063720703125
idx: 400, loss: 1.2634036540985107 time take: 4.08085298538208
idx: 500, loss: 1.5383554697036743 time ta

epoch 88 time taken: 64.22328615188599s
Training epoch: 89
idx: 0, loss: 1.529531478881836 time take: 0.022232770919799805
idx: 100, loss: 1.2865006923675537 time take: 4.094003915786743
idx: 200, loss: 1.2583352327346802 time take: 4.1184492111206055
idx: 300, loss: 1.507769227027893 time take: 4.100856304168701
idx: 400, loss: 1.0193297863006592 time take: 4.113654375076294
idx: 500, loss: 1.4214119911193848 time take: 4.112135887145996
idx: 600, loss: 1.6740670204162598 time take: 4.080480337142944
idx: 700, loss: 1.6085668802261353 time take: 4.203420400619507
idx: 800, loss: 1.4240754842758179 time take: 4.118756532669067
idx: 900, loss: 1.058815836906433 time take: 4.110085964202881
idx: 1000, loss: 1.484817624092102 time take: 4.102846145629883
idx: 1100, loss: 1.2726942300796509 time take: 4.078997611999512
idx: 1200, loss: 1.447304606437683 time take: 4.075313091278076
idx: 1300, loss: 1.3901137113571167 time take: 4.094705820083618
idx: 1400, loss: 1.263493299484253 time take

idx: 900, loss: 1.1027348041534424 time take: 4.084017992019653
idx: 1000, loss: 1.5103695392608643 time take: 4.075309753417969
idx: 1100, loss: 1.253724217414856 time take: 4.071546792984009
idx: 1200, loss: 1.4333308935165405 time take: 4.080966949462891
idx: 1300, loss: 1.4251809120178223 time take: 4.094465494155884
idx: 1400, loss: 1.304009199142456 time take: 4.242366313934326
idx: 1500, loss: 1.758765459060669 time take: 4.074569940567017
epoch 96 time taken: 63.98118543624878s
Training epoch: 97
idx: 0, loss: 1.5686594247817993 time take: 0.02225804328918457
idx: 100, loss: 1.3043993711471558 time take: 4.098306179046631
idx: 200, loss: 1.3445965051651 time take: 4.102671384811401
idx: 300, loss: 1.3716797828674316 time take: 4.1063759326934814
idx: 400, loss: 1.3724616765975952 time take: 4.074453830718994
idx: 500, loss: 1.3534278869628906 time take: 4.099323034286499
idx: 600, loss: 1.5494563579559326 time take: 4.088167667388916
idx: 700, loss: 1.5972121953964233 time take

idx: 700, loss: 1.8215558528900146 time take: 4.089768648147583
idx: 800, loss: 1.6028302907943726 time take: 4.098402976989746
idx: 900, loss: 1.3640899658203125 time take: 4.105599403381348
idx: 1000, loss: 1.705100417137146 time take: 4.103650808334351
idx: 1100, loss: 1.582532286643982 time take: 4.105262756347656
idx: 1200, loss: 1.8537640571594238 time take: 4.092085838317871
idx: 1300, loss: 1.547794222831726 time take: 4.1060216426849365
idx: 1400, loss: 1.568840503692627 time take: 4.167525768280029
idx: 1500, loss: 1.9892292022705078 time take: 4.234285116195679
epoch 3 time taken: 64.24445867538452s
Training epoch: 4
idx: 0, loss: 1.5796880722045898 time take: 0.02443075180053711
idx: 100, loss: 1.6186330318450928 time take: 4.0908730030059814
idx: 200, loss: 1.413862943649292 time take: 4.120291471481323
idx: 300, loss: 1.7179129123687744 time take: 4.114173173904419
idx: 400, loss: 1.7382628917694092 time take: 4.105002164840698
idx: 500, loss: 1.5780706405639648 time take

idx: 100, loss: 1.3550821542739868 time take: 4.09803032875061
idx: 200, loss: 1.5447969436645508 time take: 4.1024250984191895
idx: 300, loss: 1.686699390411377 time take: 4.090038061141968
idx: 400, loss: 1.5275958776474 time take: 4.081770658493042
idx: 500, loss: 1.3380173444747925 time take: 4.126019716262817
idx: 600, loss: 1.882023811340332 time take: 4.248592853546143
idx: 700, loss: 1.77780282497406 time take: 4.0906982421875
idx: 800, loss: 1.3197827339172363 time take: 4.142865419387817
idx: 900, loss: 1.217590093612671 time take: 4.119837045669556
idx: 1000, loss: 1.5047324895858765 time take: 4.108287811279297
idx: 1100, loss: 1.4834998846054077 time take: 4.10918402671814
idx: 1200, loss: 1.3419859409332275 time take: 4.102942228317261
idx: 1300, loss: 1.2439063787460327 time take: 4.149276971817017
idx: 1400, loss: 1.4010181427001953 time take: 4.109167575836182
idx: 1500, loss: 1.6313986778259277 time take: 4.082230567932129
epoch 11 time taken: 64.33807516098022s
Train

idx: 1100, loss: 1.5838984251022339 time take: 4.083296060562134
idx: 1200, loss: 1.4744807481765747 time take: 4.102509260177612
idx: 1300, loss: 1.430571436882019 time take: 4.21617579460144
idx: 1400, loss: 1.6649515628814697 time take: 4.098855495452881
idx: 1500, loss: 1.809640884399414 time take: 4.098333358764648
epoch 18 time taken: 64.27735781669617s
Training epoch: 19
idx: 0, loss: 1.5876017808914185 time take: 0.02317357063293457
idx: 100, loss: 1.4437731504440308 time take: 4.125580787658691
idx: 200, loss: 1.516552448272705 time take: 4.157597303390503
idx: 300, loss: 1.6061413288116455 time take: 4.095012187957764
idx: 400, loss: 1.1544584035873413 time take: 4.108507871627808
idx: 500, loss: 1.4189293384552002 time take: 4.110816955566406
idx: 600, loss: 1.6400063037872314 time take: 4.121153116226196
idx: 700, loss: 1.6181085109710693 time take: 4.119778394699097
idx: 800, loss: 1.5109663009643555 time take: 4.127675771713257
idx: 900, loss: 1.101135492324829 time take:

idx: 400, loss: 1.3679646253585815 time take: 4.121880054473877
idx: 500, loss: 1.416599988937378 time take: 4.283716201782227
idx: 600, loss: 1.6547847986221313 time take: 4.1161839962005615
idx: 700, loss: 1.526242971420288 time take: 4.113145589828491
idx: 800, loss: 1.3957865238189697 time take: 4.091496706008911
idx: 900, loss: 1.11345374584198 time take: 4.099486589431763
idx: 1000, loss: 1.6035429239273071 time take: 4.099576473236084
idx: 1100, loss: 1.2846940755844116 time take: 4.139935255050659
idx: 1200, loss: 1.6556192636489868 time take: 4.116240739822388
idx: 1300, loss: 1.355257272720337 time take: 4.111727714538574
idx: 1400, loss: 1.3321070671081543 time take: 4.106266021728516
idx: 1500, loss: 1.7699217796325684 time take: 4.119186162948608
epoch 26 time taken: 64.4188756942749s
Training epoch: 27
idx: 0, loss: 1.5430290699005127 time take: 0.022237300872802734
idx: 100, loss: 1.4651904106140137 time take: 4.097591876983643
idx: 200, loss: 1.3345935344696045 time tak

idx: 1400, loss: 1.4680486917495728 time take: 4.080406427383423
idx: 1500, loss: 1.8739246129989624 time take: 4.113465309143066
epoch 33 time taken: 64.19349908828735s
Training epoch: 34
idx: 0, loss: 1.3863991498947144 time take: 0.02409529685974121
idx: 100, loss: 1.3071107864379883 time take: 4.078763008117676
idx: 200, loss: 1.4220590591430664 time take: 4.095707654953003
idx: 300, loss: 1.7927629947662354 time take: 4.097944021224976
idx: 400, loss: 1.1157159805297852 time take: 4.148822784423828
idx: 500, loss: 1.4810492992401123 time take: 4.101367712020874
idx: 600, loss: 1.8527601957321167 time take: 4.0864527225494385
idx: 700, loss: 1.5076544284820557 time take: 4.120278596878052
idx: 800, loss: 1.551591396331787 time take: 4.127541780471802
idx: 900, loss: 1.0157980918884277 time take: 4.113553285598755
idx: 1000, loss: 1.3717520236968994 time take: 4.086718559265137
idx: 1100, loss: 1.387721300125122 time take: 4.250607252120972
idx: 1200, loss: 1.332884669303894 time ta

idx: 700, loss: 1.6480903625488281 time take: 4.108511686325073
idx: 800, loss: 1.3723708391189575 time take: 4.10209846496582
idx: 900, loss: 1.1782196760177612 time take: 4.085527420043945
idx: 1000, loss: 1.2998178005218506 time take: 4.106117010116577
idx: 1100, loss: 1.5691707134246826 time take: 4.104105234146118
idx: 1200, loss: 1.2471956014633179 time take: 4.120527267456055
idx: 1300, loss: 1.3053059577941895 time take: 4.11917781829834
idx: 1400, loss: 1.4002877473831177 time take: 4.093661308288574
idx: 1500, loss: 1.7815016508102417 time take: 4.100042343139648
epoch 41 time taken: 64.21365070343018s
Training epoch: 42
idx: 0, loss: 1.299797534942627 time take: 0.0226743221282959
idx: 100, loss: 1.2379684448242188 time take: 4.103567361831665
idx: 200, loss: 1.371834635734558 time take: 4.118865251541138
idx: 300, loss: 1.5232577323913574 time take: 4.218372821807861
idx: 400, loss: 1.1854037046432495 time take: 4.122204780578613
idx: 500, loss: 1.1973145008087158 time take

idx: 100, loss: 1.424793004989624 time take: 4.103301763534546
idx: 200, loss: 1.2137770652770996 time take: 4.098669052124023
idx: 300, loss: 1.3589634895324707 time take: 4.091562271118164
idx: 400, loss: 1.2612758874893188 time take: 4.10532283782959
idx: 500, loss: 1.3342030048370361 time take: 4.1004602909088135
idx: 600, loss: 1.608097791671753 time take: 4.13907527923584
idx: 700, loss: 1.520241141319275 time take: 4.085975170135498
idx: 800, loss: 1.337775707244873 time take: 4.100553274154663
idx: 900, loss: 1.1053203344345093 time take: 4.093858957290649
idx: 1000, loss: 1.6519451141357422 time take: 4.2368481159210205
idx: 1100, loss: 1.3130204677581787 time take: 4.1125757694244385
idx: 1200, loss: 1.3044310808181763 time take: 4.091359853744507
idx: 1300, loss: 1.1592061519622803 time take: 4.113505840301514
idx: 1400, loss: 1.4221608638763428 time take: 4.0931174755096436
idx: 1500, loss: 1.4776231050491333 time take: 4.123480796813965
epoch 49 time taken: 64.277570009231

idx: 1100, loss: 1.4172343015670776 time take: 4.111794471740723
idx: 1200, loss: 1.3341182470321655 time take: 4.100377559661865
idx: 1300, loss: 1.221361517906189 time take: 4.098172187805176
idx: 1400, loss: 1.2690623998641968 time take: 4.094125986099243
idx: 1500, loss: 1.996870994567871 time take: 4.146177291870117
epoch 56 time taken: 64.22923636436462s
Training epoch: 57
idx: 0, loss: 1.5732126235961914 time take: 0.02275705337524414
idx: 100, loss: 1.3747061491012573 time take: 4.097727298736572
idx: 200, loss: 1.313018560409546 time take: 4.221260070800781
idx: 300, loss: 1.6746244430541992 time take: 4.104984760284424
idx: 400, loss: 1.1155050992965698 time take: 4.10541033744812
idx: 500, loss: 1.4160501956939697 time take: 4.103991508483887
idx: 600, loss: 1.5836143493652344 time take: 4.076968193054199
idx: 700, loss: 1.5362666845321655 time take: 4.085448980331421
idx: 800, loss: 1.4153417348861694 time take: 4.102370023727417
idx: 900, loss: 0.9147648811340332 time take

idx: 400, loss: 1.2753552198410034 time take: 4.102472543716431
idx: 500, loss: 1.294632077217102 time take: 4.0820393562316895
idx: 600, loss: 1.3934688568115234 time take: 4.082754611968994
idx: 700, loss: 1.5601685047149658 time take: 4.106364965438843
idx: 800, loss: 1.4335323572158813 time take: 4.102339267730713
idx: 900, loss: 1.1037274599075317 time take: 4.259050369262695
idx: 1000, loss: 1.4589593410491943 time take: 4.084752082824707
idx: 1100, loss: 1.343356966972351 time take: 4.0998759269714355
idx: 1200, loss: 1.3339025974273682 time take: 4.132692813873291
idx: 1300, loss: 1.2846696376800537 time take: 4.121488571166992
idx: 1400, loss: 1.5326499938964844 time take: 4.1685919761657715
idx: 1500, loss: 1.7895243167877197 time take: 4.142474889755249
epoch 64 time taken: 64.30441641807556s
Training epoch: 65
idx: 0, loss: 1.5037566423416138 time take: 0.022205352783203125
idx: 100, loss: 1.2996655702590942 time take: 4.088551998138428
idx: 200, loss: 1.2211716175079346 ti

idx: 1400, loss: 1.2188926935195923 time take: 4.084561109542847
idx: 1500, loss: 1.4810242652893066 time take: 4.083250045776367
epoch 71 time taken: 64.19235730171204s
Training epoch: 72
idx: 0, loss: 1.3378510475158691 time take: 0.022359371185302734
idx: 100, loss: 1.3452297449111938 time take: 4.259092569351196
idx: 200, loss: 1.3734619617462158 time take: 4.096637487411499
idx: 300, loss: 1.5527361631393433 time take: 4.101419687271118
idx: 400, loss: 1.1521098613739014 time take: 4.100008487701416
idx: 500, loss: 1.3401960134506226 time take: 4.1183342933654785
idx: 600, loss: 1.409091591835022 time take: 4.09293794631958
idx: 700, loss: 1.5150190591812134 time take: 4.102417469024658
idx: 800, loss: 1.5238255262374878 time take: 4.07917332649231
idx: 900, loss: 1.2209625244140625 time take: 4.095714092254639
idx: 1000, loss: 1.2959414720535278 time take: 4.125247478485107
idx: 1100, loss: 1.443795084953308 time take: 4.114696979522705
idx: 1200, loss: 1.3726919889450073 time ta

idx: 700, loss: 1.519424319267273 time take: 4.125104188919067
idx: 800, loss: 1.2703531980514526 time take: 4.2646729946136475
idx: 900, loss: 1.135718822479248 time take: 4.083980083465576
idx: 1000, loss: 1.6362076997756958 time take: 4.086985111236572
idx: 1100, loss: 1.3018251657485962 time take: 4.142906665802002
idx: 1200, loss: 1.3502687215805054 time take: 4.103118181228638
idx: 1300, loss: 1.166398048400879 time take: 4.078049898147583
idx: 1400, loss: 1.2921169996261597 time take: 4.115442276000977
idx: 1500, loss: 1.9110538959503174 time take: 4.112778902053833
epoch 79 time taken: 64.3604347705841s
Training epoch: 80
idx: 0, loss: 1.333293080329895 time take: 0.02221512794494629
idx: 100, loss: 1.3626527786254883 time take: 4.093260765075684
idx: 200, loss: 1.2473862171173096 time take: 4.0919365882873535
idx: 300, loss: 1.6395701169967651 time take: 4.095915079116821
idx: 400, loss: 1.0481903553009033 time take: 4.092740535736084
idx: 500, loss: 1.4515244960784912 time ta

idx: 100, loss: 1.2760744094848633 time take: 4.09189248085022
idx: 200, loss: 1.3040345907211304 time take: 4.085116863250732
idx: 300, loss: 1.5675663948059082 time take: 4.104551076889038
idx: 400, loss: 1.4202741384506226 time take: 4.152747392654419
idx: 500, loss: 1.3985496759414673 time take: 4.092851400375366
idx: 600, loss: 1.6065247058868408 time take: 4.077431678771973
idx: 700, loss: 1.5580533742904663 time take: 4.105807781219482
idx: 800, loss: 1.4226478338241577 time take: 4.101495027542114
idx: 900, loss: 1.082322359085083 time take: 4.107697486877441
idx: 1000, loss: 1.4458897113800049 time take: 4.11318039894104
idx: 1100, loss: 1.4611766338348389 time take: 4.092067718505859
idx: 1200, loss: 1.4937890768051147 time take: 4.097015380859375
idx: 1300, loss: 1.1647557020187378 time take: 4.089705228805542
idx: 1400, loss: 1.2088645696640015 time take: 4.084718942642212
idx: 1500, loss: 1.6264084577560425 time take: 4.198615789413452
epoch 87 time taken: 64.1677229404449

idx: 1100, loss: 1.3782877922058105 time take: 4.090161323547363
idx: 1200, loss: 1.5683337450027466 time take: 4.107475996017456
idx: 1300, loss: 1.3757293224334717 time take: 4.126867771148682
idx: 1400, loss: 1.2717082500457764 time take: 4.129211664199829
idx: 1500, loss: 1.7574912309646606 time take: 4.089621305465698
epoch 94 time taken: 64.16421890258789s
Training epoch: 95
idx: 0, loss: 1.4611996412277222 time take: 0.024432897567749023
idx: 100, loss: 1.3799128532409668 time take: 4.109936952590942
idx: 200, loss: 1.3747881650924683 time take: 4.11156439781189
idx: 300, loss: 1.655493974685669 time take: 4.103376626968384
idx: 400, loss: 1.190886378288269 time take: 4.081382513046265
idx: 500, loss: 1.3663756847381592 time take: 4.107540845870972
idx: 600, loss: 1.549180507659912 time take: 4.14205265045166
idx: 700, loss: 1.5422576665878296 time take: 4.1755454540252686
idx: 800, loss: 1.540257453918457 time take: 4.105756521224976
idx: 900, loss: 0.9602973461151123 time take

{0: {'dataset': 'cifar',
  'FPR@95': '92.51',
  'AUROC': '56.12',
  'AUPR_IN': '56.08',
  'AUPR_OUT': '54.96',
  'CCR_4': '0.03',
  'CCR_3': '0.42',
  'CCR_2': '2.08',
  'CCR_1': '10.78',
  'ACC': '41.22',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'},
 1: {'dataset': 'cifar',
  'FPR@95': '91.57',
  'AUROC': '60.16',
  'AUPR_IN': '61.22',
  'AUPR_OUT': '57.54',
  'CCR_4': '0.20',
  'CCR_3': '0.60',
  'CCR_2': '3.31',
  'CCR_1': '14.30',
  'ACC': '43.02',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'},
 2: {'dataset': 'cifar',
  'FPR@95': '91.37',
  'AUROC': '62.52',
  'AUPR_IN': '62.89',
  'AUPR_OUT': '59.61',
  'CCR_4': '0.05',
  'CCR_3': '0.85',
  'CCR_2': '4.29',
  'CCR_1': '15.80',
  'ACC': '47.17',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'}}

#### Study 2 (b.) Adam + Softplus + odin

In [None]:
config_cifar_adam_softplus_odin = {
    "batch_size": 32,
    "n_classes": 10,
    "dataset_name": "cifar",
    "epochs": 100,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
    "activation_function_type": "softplus",
    "network": "resnet50",
    "postprocessor_type": "odin",
    "trials": 3,
    "dataset_type": "cifar",
    "results_dir": "cifar10-study",
    "pretrained": False
}
config_cifar_adam_softplus_odin["data_loaders"] = get_data_loaders(config_cifar_adam_softplus_odin)
run_full_oodn_pipeline(config_cifar_adam_softplus_odin)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Running model: models/cifar_resnet50_odin_softplus_Adam_0.pkl...
Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.9400904178619385 time take: 0.022324323654174805
idx: 100, loss: 5.485348701477051 time take: 4.096525430679321
idx: 200, loss: 5.1429057121276855 time take: 4.111398935317993
idx: 300, loss: 3.4952378273010254 time take: 4.126551628112793
idx: 400, loss: 3.251206159591675 time take: 4.083424806594849
idx: 500, loss: 3.5177512168884277 time take: 4.073462724685669
idx: 600, loss: 3.0264832973480225 time take: 4.1066319942474365
idx: 700, loss: 3.4363059997558594 time take: 4.0969626903533936
idx: 800, loss: 2.5099704265594482 time take: 4.234477519989014
idx: 900, loss: 2.5756146907806396 time take: 4.117691516876221
idx: 1000, loss: 2.4096317291259766 time take: 4.1335954666137695
idx: 1100, loss: 2.5305287837982178 time take: 4.09361171722

idx: 700, loss: 2.8071746826171875 time take: 4.108215808868408
idx: 800, loss: 2.7940328121185303 time take: 4.104716062545776
idx: 900, loss: 2.247343063354492 time take: 4.099836111068726
idx: 1000, loss: 1.9820953607559204 time take: 4.095187425613403
idx: 1100, loss: 2.071901798248291 time take: 4.101616859436035
idx: 1200, loss: 1.7930326461791992 time take: 4.12448787689209
idx: 1300, loss: 2.8186168670654297 time take: 4.143683671951294
idx: 1400, loss: 2.2102015018463135 time take: 4.1360859870910645
idx: 1500, loss: 2.989084243774414 time take: 4.23716402053833
epoch 7 time taken: 64.39987754821777s
Training epoch: 8
idx: 0, loss: 2.5929460525512695 time take: 0.022413253784179688
idx: 100, loss: 2.3263602256774902 time take: 4.107846021652222
idx: 200, loss: 1.8463211059570312 time take: 4.108522653579712
idx: 300, loss: 3.3185489177703857 time take: 4.12127947807312
idx: 400, loss: 2.661604881286621 time take: 4.150969505310059
idx: 500, loss: 3.2336692810058594 time take: 

idx: 100, loss: 2.3893938064575195 time take: 4.119708061218262
idx: 200, loss: 2.375619411468506 time take: 4.13898229598999
idx: 300, loss: 2.826690435409546 time take: 4.105683326721191
idx: 400, loss: 2.2055325508117676 time take: 4.134119749069214
idx: 500, loss: 2.241858959197998 time take: 4.130999803543091
idx: 600, loss: 2.884021759033203 time take: 4.235384464263916
idx: 700, loss: 3.8363029956817627 time take: 4.084251403808594
idx: 800, loss: 2.290513515472412 time take: 4.100091218948364
idx: 900, loss: 2.2360012531280518 time take: 4.10809326171875
idx: 1000, loss: 2.456415891647339 time take: 4.1208178997039795
idx: 1100, loss: 2.9646921157836914 time take: 4.1170814037323
idx: 1200, loss: 2.350405693054199 time take: 4.144125699996948
idx: 1300, loss: 2.3812310695648193 time take: 4.128928184509277
idx: 1400, loss: 1.8794490098953247 time take: 4.111865282058716
idx: 1500, loss: 2.818408727645874 time take: 4.113945484161377
epoch 15 time taken: 64.50151300430298s
Train

idx: 1100, loss: 3.1001486778259277 time take: 4.107495307922363
idx: 1200, loss: 1.9159338474273682 time take: 4.121116876602173
idx: 1300, loss: 2.230834484100342 time take: 4.281702518463135
idx: 1400, loss: 2.03924298286438 time take: 4.121699094772339
idx: 1500, loss: 2.6470682621002197 time take: 4.09529447555542
epoch 22 time taken: 64.57038569450378s
Training epoch: 23
idx: 0, loss: 2.8895890712738037 time take: 0.022577524185180664
idx: 100, loss: 1.872208595275879 time take: 4.095221519470215
idx: 200, loss: 2.1307146549224854 time take: 4.101665258407593
idx: 300, loss: 3.1175224781036377 time take: 4.108230113983154
idx: 400, loss: 1.7871814966201782 time take: 4.125839710235596
idx: 500, loss: 3.200544834136963 time take: 4.150827169418335
idx: 600, loss: 2.0987298488616943 time take: 4.119591236114502
idx: 700, loss: 4.107021331787109 time take: 4.1043925285339355
idx: 800, loss: 2.118429660797119 time take: 4.124672174453735
idx: 900, loss: 2.266782522201538 time take: 4

idx: 400, loss: 1.880631923675537 time take: 4.089143753051758
idx: 500, loss: 2.5125842094421387 time take: 4.2229673862457275
idx: 600, loss: 2.9672224521636963 time take: 4.093381881713867
idx: 700, loss: 3.4907662868499756 time take: 4.1004908084869385
idx: 800, loss: 1.7094521522521973 time take: 4.1288909912109375
idx: 900, loss: 1.9909257888793945 time take: 4.107529163360596
idx: 1000, loss: 2.6961286067962646 time take: 4.112977504730225
idx: 1100, loss: 1.714165449142456 time take: 4.09544825553894
idx: 1200, loss: 2.0552055835723877 time take: 4.104825735092163
idx: 1300, loss: 2.3521339893341064 time take: 4.114988565444946
idx: 1400, loss: 1.9739277362823486 time take: 4.146896123886108
idx: 1500, loss: 2.1803512573242188 time take: 4.115297079086304
epoch 30 time taken: 64.35524415969849s
Training epoch: 31
idx: 0, loss: 2.738518476486206 time take: 0.022493362426757812
idx: 100, loss: 2.952077865600586 time take: 4.128525972366333
idx: 200, loss: 1.952338695526123 time t

idx: 1400, loss: 2.198737621307373 time take: 4.089060068130493
idx: 1500, loss: 2.6355690956115723 time take: 4.081958293914795
epoch 37 time taken: 64.35480070114136s
Training epoch: 38
idx: 0, loss: 2.861743211746216 time take: 0.0224001407623291
idx: 100, loss: 2.3018598556518555 time take: 4.088655948638916
idx: 200, loss: 2.1756398677825928 time take: 4.100988149642944
idx: 300, loss: 3.3657402992248535 time take: 4.09671950340271
idx: 400, loss: 2.2030630111694336 time take: 4.120457649230957
idx: 500, loss: 1.734230875968933 time take: 4.09398341178894
idx: 600, loss: 3.052607774734497 time take: 4.131654500961304
idx: 700, loss: 3.6072535514831543 time take: 4.149821758270264
idx: 800, loss: 1.6270735263824463 time take: 4.0988147258758545
idx: 900, loss: 2.0465190410614014 time take: 4.095991611480713
idx: 1000, loss: 2.7318217754364014 time take: 4.11495566368103
idx: 1100, loss: 2.085622787475586 time take: 4.242206335067749
idx: 1200, loss: 2.1195294857025146 time take: 4.

idx: 700, loss: 3.602370023727417 time take: 4.127084016799927
idx: 800, loss: 1.8973655700683594 time take: 4.123646020889282
idx: 900, loss: 2.041583776473999 time take: 4.119169235229492
idx: 1000, loss: 2.1989307403564453 time take: 4.1262123584747314
idx: 1100, loss: 2.2691056728363037 time take: 4.1335790157318115
idx: 1200, loss: 2.013474225997925 time take: 4.140768051147461
idx: 1300, loss: 2.7644107341766357 time take: 4.12987208366394
idx: 1400, loss: 2.271374464035034 time take: 4.1233134269714355
idx: 1500, loss: 3.3051917552948 time take: 4.140595436096191
epoch 45 time taken: 64.65876173973083s
Training epoch: 46
idx: 0, loss: 2.649014472961426 time take: 0.022432565689086914
idx: 100, loss: 2.088911771774292 time take: 4.101572275161743
idx: 200, loss: 2.382035732269287 time take: 4.258214473724365
idx: 300, loss: 2.535179376602173 time take: 4.123858690261841
idx: 400, loss: 2.445646047592163 time take: 4.148158550262451
idx: 500, loss: 2.233325958251953 time take: 4.1

idx: 100, loss: 2.2234644889831543 time take: 4.110321521759033
idx: 200, loss: 1.8447587490081787 time take: 4.169116258621216
idx: 300, loss: 3.543584108352661 time take: 4.115856885910034
idx: 400, loss: 2.252735137939453 time take: 4.10756778717041
idx: 500, loss: 3.326802968978882 time take: 4.121130704879761
idx: 600, loss: 2.7961556911468506 time take: 4.113459348678589
idx: 700, loss: 3.663034677505493 time take: 4.105225324630737
idx: 800, loss: 1.5799851417541504 time take: 4.1276280879974365
idx: 900, loss: 1.9112210273742676 time take: 4.253338098526001
idx: 1000, loss: 2.103248357772827 time take: 4.118608713150024
idx: 1100, loss: 1.968589425086975 time take: 4.112543106079102
idx: 1200, loss: 2.5945942401885986 time take: 4.117408990859985
idx: 1300, loss: 2.068531036376953 time take: 4.133395671844482
idx: 1400, loss: 1.8901727199554443 time take: 4.108774900436401
idx: 1500, loss: 2.7985382080078125 time take: 4.113760948181152
epoch 53 time taken: 64.50909423828125s
T

idx: 1100, loss: 2.153503894805908 time take: 4.143754720687866
idx: 1200, loss: 2.050445556640625 time take: 4.11003565788269
idx: 1300, loss: 2.248945713043213 time take: 4.1731040477752686
idx: 1400, loss: 1.875996470451355 time take: 4.1295459270477295
idx: 1500, loss: 2.2737810611724854 time take: 4.131067276000977
epoch 60 time taken: 64.8074700832367s
Training epoch: 61
idx: 0, loss: 2.4433789253234863 time take: 0.022348642349243164
idx: 100, loss: 1.9884772300720215 time take: 4.155824899673462
idx: 200, loss: 1.8817042112350464 time take: 4.144668817520142
idx: 300, loss: 2.0982484817504883 time take: 4.132617950439453
idx: 400, loss: 2.3696346282958984 time take: 4.139230489730835
idx: 500, loss: 2.0186593532562256 time take: 4.132508039474487
idx: 600, loss: 2.731252670288086 time take: 4.148940324783325
idx: 700, loss: 2.898787260055542 time take: 4.115652561187744
idx: 800, loss: 1.6477309465408325 time take: 4.1111063957214355
idx: 900, loss: 1.7909482717514038 time take

idx: 400, loss: 3.043217897415161 time take: 4.109385251998901
idx: 500, loss: 2.217162609100342 time take: 4.104505300521851
idx: 600, loss: 2.794947862625122 time take: 4.257663011550903
idx: 700, loss: 3.356848955154419 time take: 4.1457109451293945
idx: 800, loss: 1.8626168966293335 time take: 4.120071649551392
idx: 900, loss: 2.5152976512908936 time take: 4.171619892120361
idx: 1000, loss: 2.3391878604888916 time take: 4.144917011260986
idx: 1100, loss: 1.9030312299728394 time take: 4.1293275356292725
idx: 1200, loss: 2.2430484294891357 time take: 4.130168914794922
idx: 1300, loss: 2.3559155464172363 time take: 4.139855861663818
idx: 1400, loss: 2.19649076461792 time take: 4.131213903427124
idx: 1500, loss: 2.692152976989746 time take: 4.144739866256714
epoch 68 time taken: 64.68627142906189s
Training epoch: 69
idx: 0, loss: 2.2218611240386963 time take: 0.022417306900024414
idx: 100, loss: 1.898298740386963 time take: 4.1468424797058105
idx: 200, loss: 2.113057851791382 time take

idx: 1400, loss: 1.990929365158081 time take: 4.138254880905151
idx: 1500, loss: 2.631664514541626 time take: 4.112443685531616
epoch 75 time taken: 64.76096820831299s
Training epoch: 76
idx: 0, loss: 2.102125883102417 time take: 0.022577762603759766
idx: 100, loss: 2.7370262145996094 time take: 4.1522133350372314
idx: 200, loss: 2.134782314300537 time take: 4.119861602783203
idx: 300, loss: 3.6552302837371826 time take: 4.121166467666626
idx: 400, loss: 2.3762624263763428 time take: 4.10661506652832
idx: 500, loss: 2.9749948978424072 time take: 4.132277488708496
idx: 600, loss: 3.044435977935791 time take: 4.116989850997925
idx: 700, loss: 2.469808578491211 time take: 4.119347333908081
idx: 800, loss: 1.6381781101226807 time take: 4.084413766860962
idx: 900, loss: 1.7733861207962036 time take: 4.100072145462036
idx: 1000, loss: 1.718733787536621 time take: 4.107823371887207
idx: 1100, loss: 1.7473633289337158 time take: 4.209038257598877
idx: 1200, loss: 2.595149278640747 time take: 4

idx: 700, loss: 2.612661600112915 time take: 4.091521501541138
idx: 800, loss: 2.0829434394836426 time take: 4.099865913391113
idx: 900, loss: 2.515453577041626 time take: 4.094320774078369
idx: 1000, loss: 2.3880455493927 time take: 4.1323418617248535
idx: 1100, loss: 1.693222999572754 time take: 4.10141134262085
idx: 1200, loss: 1.9837976694107056 time take: 4.083273887634277
idx: 1300, loss: 3.0197913646698 time take: 4.1062328815460205
idx: 1400, loss: 1.9967581033706665 time take: 4.102694988250732
idx: 1500, loss: 3.1266820430755615 time take: 4.109030485153198
epoch 83 time taken: 64.18324422836304s
Training epoch: 84
idx: 0, loss: 2.697432041168213 time take: 0.022053956985473633
idx: 100, loss: 1.75412917137146 time take: 4.078891038894653
idx: 200, loss: 2.5410826206207275 time take: 4.085751533508301
idx: 300, loss: 2.853717565536499 time take: 4.230431079864502
idx: 400, loss: 2.5781524181365967 time take: 4.1088902950286865
idx: 500, loss: 3.136338949203491 time take: 4.08

idx: 100, loss: 2.1674561500549316 time take: 4.09288215637207
idx: 200, loss: 2.438462495803833 time take: 4.101112604141235
idx: 300, loss: 2.837115526199341 time take: 4.105496883392334
idx: 400, loss: 1.471131443977356 time take: 4.172549724578857
idx: 500, loss: 1.8333771228790283 time take: 4.083679437637329
idx: 600, loss: 2.7564666271209717 time take: 4.098332166671753
idx: 700, loss: 4.265981674194336 time take: 4.101928472518921
idx: 800, loss: 1.6334683895111084 time take: 4.103339910507202
idx: 900, loss: 1.4866218566894531 time take: 4.120863676071167
idx: 1000, loss: 2.228166341781616 time take: 4.10866379737854
idx: 1100, loss: 2.2818140983581543 time take: 4.200334072113037
idx: 1200, loss: 1.6838973760604858 time take: 4.079050302505493
idx: 1300, loss: 2.7004942893981934 time take: 4.087135553359985
idx: 1400, loss: 1.6852507591247559 time take: 4.090198755264282
idx: 1500, loss: 1.8014037609100342 time take: 4.085691213607788
epoch 91 time taken: 64.22058749198914s
T

idx: 1100, loss: 2.0081732273101807 time take: 4.120847463607788
idx: 1200, loss: 2.369462251663208 time take: 4.113290071487427
idx: 1300, loss: 2.4151012897491455 time take: 4.165461301803589
idx: 1400, loss: 1.9285134077072144 time take: 4.110636472702026
idx: 1500, loss: 2.796034574508667 time take: 4.118157863616943
epoch 98 time taken: 64.34691333770752s
Training epoch: 99
idx: 0, loss: 2.399913787841797 time take: 0.022266626358032227
idx: 100, loss: 1.7870452404022217 time take: 4.098530530929565
idx: 200, loss: 2.4170544147491455 time take: 4.2479493618011475
idx: 300, loss: 3.19677996635437 time take: 4.099990367889404
idx: 400, loss: 1.5875859260559082 time take: 4.104693651199341
idx: 500, loss: 2.571455717086792 time take: 4.13665246963501
idx: 600, loss: 2.8675878047943115 time take: 4.118646144866943
idx: 700, loss: 3.0254251956939697 time take: 4.092961072921753
idx: 800, loss: 1.8165124654769897 time take: 4.094493389129639
idx: 900, loss: 2.3485748767852783 time take: