In [1]:
import sys
import os
sys.path.append(os.path.abspath('../..')) # include top level package in python path

In [2]:
import torch
from comparison.examples.vae_mnist import VAE_MNIST
from comparison.loss import ELBO, IWAE_loss, CIWAE_loss, PIWAE_loss, DREG_loss
from comparison.metric import IWAE_64, log_px
from comparison.metric import IWAE_metric, CIWAE_metric, PIWAE_metric, sample_ess
from tqdm.notebook import tqdm
import csv




In [3]:
# Use the GPU when PyTorch can see one and fall back to the CPU otherwise
# (CPU works on every machine). Assign `device` explicitly to override.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import time
# Directory that checkpoints and result CSVs are written to / read from.
model_dir = './_modeldata/'
# Experiment tag prefixed to every artefact file name.
idstr = "bernoulli50"
print(idstr)

bernoulli50


In [4]:
import re
import pandas as pd

class MNISTTrainData:
    """File-naming scheme and load/save helpers for the VAE training artefacts.

    All artefacts live in ``model_dir`` and start with ``idstr``:

    * ``<idstr>-start-model.pt``                   -- initial weights
    * ``<idstr>-<loss>-r<run_no>-<tag>-model.pt``  -- checkpoint, where ``<tag>``
      is ``e<epoch>`` or ``final``
    * ``<idstr>-<loss>-r<run_no>-results.csv``     -- metric table of one run

    ``model_dir`` may be given with or without a trailing path separator.
    """

    def __init__(self, model_dir, idstr):
        self.model_dir = model_dir
        self.idstr = idstr

    def _path(self, name):
        """Return ``name`` placed inside ``model_dir``."""
        return os.path.join(self.model_dir, name)

    def model_from_filename(self, filename):
        """Parse a checkpoint file name into ``(loss, run_no, epoch)``.

        ``epoch`` is ``None`` for the ``final`` checkpoint. Returns ``None`` when
        ``filename`` is not a checkpoint of this ``idstr``.
        """
        # (.+) is greedy and backtracks to the last "-r<digits>-<tag>-model.pt",
        # so loss names that contain dashes (e.g. "miwae-8-8") are kept whole.
        pattern = r"^%s-(.+)-r([0-9]+)-(final|e[0-9]+)-model\.pt$" % re.escape(self.idstr)
        match = re.search(pattern, filename)
        if match is None:
            return None
        loss, run_no, tag = match.groups()
        return (loss, int(run_no), self.epoch_from_str(tag))

    def results_from_filename(self, filename):
        """Parse a results file name into ``(loss, run_no)``, or ``None`` if it is not one."""
        pattern = r"^%s-(.+)-r([0-9]+)-results\.csv$" % re.escape(self.idstr)
        match = re.search(pattern, filename)
        if match is None:
            return None
        loss, run_no = match.groups()
        return (loss, int(run_no))

    def get_files(self):
        """List parsed checkpoint and results files found in ``model_dir``.

        Returns ``(models, results)``: lists of the tuples produced by
        ``model_from_filename`` and ``results_from_filename``. Unrelated files
        are skipped.
        """
        filenames = os.listdir(self.model_dir or os.curdir)
        models = [m for m in map(self.model_from_filename, filenames) if m]
        results = [r for r in map(self.results_from_filename, filenames) if r]
        return models, results

    def load_all_results(self):
        """Load every results CSV in ``model_dir``, grouped by loss name.

        Returns a dict ``loss -> [(run_no, DataFrame), ...]`` in directory-listing
        order. The known loss names are always present (possibly with an empty
        list). Result files for any other loss get their own key instead of
        raising ``KeyError``.
        """
        results = {loss: [] for loss in (
            'vae', 'iwae', 'ciwae-05', 'miwae-8-8', 'piwae-8-8', 'dreg-1-64',
        )}
        _, result_files = self.get_files()
        for loss, run_no in result_files:
            results.setdefault(loss, []).append((run_no, self.load_results(loss, run_no)))
        return results

    def epoch_str(self, epoch = None):
        """Encode an epoch as a file-name tag: ``None`` -> ``"final"``, ``12`` -> ``"e12"``."""
        return "final" if epoch is None else "e" + str(epoch)

    def epoch_from_str(self, s):
        """Inverse of ``epoch_str``.

        Raises:
            ValueError: if ``s`` is neither ``"final"`` nor ``"e<digits>"``.
        """
        if s == "final":
            return None
        match = re.fullmatch("e([0-9]+)", s)
        if match is None:
            raise ValueError("unrecognised epoch tag: %r" % (s,))
        return int(match.group(1))

    def start_prefix(self):
        """Path prefix of the shared initial-weights file."""
        return self._path(self.idstr + "-start")

    def run_prefix(self, loss, run_no):
        """Path prefix shared by all artefacts of run ``run_no`` of ``loss``."""
        return self._path("%s-%s-r%s" % (self.idstr, loss, run_no))

    def start_model_str(self):
        return self.start_prefix() + "-model.pt"

    def model_str(self, epoch, *args, **kargs):
        """Checkpoint path; remaining arguments are forwarded to ``run_prefix``."""
        return self.run_prefix(*args, **kargs) + "-" + self.epoch_str(epoch) + "-model.pt"

    def results_str(self, *args, **kargs):
        """Results-CSV path; arguments are forwarded to ``run_prefix``."""
        return self.run_prefix(*args, **kargs) + "-results.csv"

    def load_start(self, model, map_location=None):
        """Load the initial weights into ``model``.

        ``map_location`` is passed to ``torch.load``. Use ``'cpu'`` to read
        GPU-saved weights on a machine without CUDA.
        """
        return model.load_state_dict(torch.load(self.start_model_str(), map_location=map_location))

    def save_start(self, model):
        torch.save(model.state_dict(), self.start_model_str())

    def load_model(self, model, loss, run_no, epoch, map_location=None):
        """Load a checkpoint (``epoch=None`` -> final) into ``model``."""
        filename = self.model_str(epoch, loss, run_no)
        return model.load_state_dict(torch.load(filename, map_location=map_location))

    def save_model(self, model, loss, run_no, epoch):
        """Save ``model``'s state dict as a checkpoint (``epoch=None`` -> final)."""
        filename = self.model_str(epoch, loss, run_no)
        torch.save(model.state_dict(), filename)

    def load_results(self, loss, run_no):
        """Read the metric table of one run.

        ``save_results`` writes the DataFrame index as the first CSV column. It
        is restored as the index here rather than showing up as an
        ``'Unnamed: 0'`` column.
        """
        filename = self.results_str(loss, run_no)
        return pd.read_csv(filename, index_col=0)

    def save_results(self, test_metrics, loss, run_no):
        """Write ``test_metrics`` (a DataFrame) to the run's results CSV, index included."""
        filename = self.results_str(loss, run_no)
        test_metrics.to_csv(filename)
            
# Shared artefact store used by the training cells below.
mnist_train_data = MNISTTrainData(model_dir=model_dir, idstr=idstr)

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

# Dynamic binarisation: each image is converted to a [0, 1] tensor and then
# replaced by Bernoulli samples whenever the dataset item is fetched.
_binarize = transforms.Compose([
    transforms.ToTensor(),
    torch.bernoulli,
])

train_dataset = datasets.MNIST('./_mnist', train=True, download=True, transform=_binarize)
test_dataset = datasets.MNIST('./_mnist', train=False, download=True, transform=_binarize)


def static_copy(dataset):
    """Materialise ``dataset`` into an in-memory ``TensorDataset``.

    ``dataset`` must yield ``(image_tensor, label)`` pairs with equally shaped
    images. It is iterated exactly once. Any random transform (such as Bernoulli
    binarisation) is therefore sampled a single time per image and then frozen,
    and costs one pass instead of two.

    Returns:
        TensorDataset(xs, ys) with xs = stacked images (N, *image_shape) and
        ys = tensor of the N labels.
    """
    images, labels = [], []
    for image, label in dataset:
        images.append(image)
        labels.append(label)
    # torch.stack adds the leading batch dim, same as cat(unsqueeze(0)).
    xs = torch.stack(images)
    ys = torch.tensor(labels)
    return TensorDataset(xs, ys)

# Set to True to freeze one binarised sample of every image (static
# binarisation) instead of resampling it on each access. This replaces the
# former commented-out toggle, whose test line assigned a misspelled
# variable and would have left the test set dynamic.
STATIC_BINARIZATION = False
if STATIC_BINARIZATION:
    train_dataset = static_copy(train_dataset)
    test_dataset = static_copy(test_dataset)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# We set a low batch_size for sampling IWAE with K=5000
test_loader = DataLoader(test_dataset, batch_size=48, shuffle=True)


In [6]:
def train_epoch(model, optims, dataloader, loss_function, label):
    """Run one optimisation pass over ``dataloader`` and report the mean loss.

    ``loss_function(model, imgs)`` returns either a single objective or a tuple
    with one objective per optimiser (the PIWAE case). Objectives are
    maximised, so their negation is minimised.

    Args:
        model: the model being trained; switched to train mode.
        optims: list of optimisers. For a single objective every optimiser is
            stepped on the shared gradient. For a tuple, objective ``i`` alone
            provides the gradients for the parameters of ``optims[i]``.
        dataloader: yields ``(imgs, labels)`` batches; images are flattened to 784.
        loss_function: callable as described above.
        label: shown in the progress-bar description.

    Returns:
        Mean negated objective over the batches as a tensor: a scalar for a
        single objective, a vector with one entry per objective for a tuple.
        Returns ``None`` if the loader is empty. The value is also printed.
    """
    model.train()
    batch_losses = []
    for imgs, _ in tqdm(dataloader, desc="Training Epoch №%s" % label, leave=False):
        imgs = imgs.view(-1, 28*28).to(device)
        losses = loss_function(model, imgs)

        if type(losses) is not tuple:
            loss = -losses
            batch_losses.append(loss.item())
            # Step every optimiser: with a single objective but several
            # optimisers (e.g. split encoder/decoder), stepping only optims[0]
            # would leave the other parameter groups untrained.
            for optim in optims:
                optim.zero_grad()
            loss.backward()
            for optim in optims:
                optim.step()
        else:  # PIWAE: one objective per optimiser
            batch_losses.append([-l.item() for l in losses])
            grads = []
            for idx, (optim, loss) in enumerate(zip(optims, losses)):
                loss = -loss
                # Clear all grads so each objective's gradient is captured alone.
                for p in model.parameters():
                    p.grad = None
                # Keep the graph alive for every backward pass but the last.
                loss.backward(retain_graph=idx<len(optims)-1)
                grads.append([p.grad for p in optim.param_groups[0]["params"]])

            # Give each optimiser's parameters the gradient of its own objective.
            for idx, optim in enumerate(optims):
                for g, p in zip(grads[idx], optim.param_groups[0]["params"]):
                    p.grad = g

            for optim in optims:
                optim.step()

    if not batch_losses:
        return None
    # mean over batches; 1-D (scalar case) gives a 0-d tensor as before
    mean_loss = torch.Tensor(batch_losses).mean(dim=0)
    print(mean_loss)
    return mean_loss



In [7]:
def test_epoch(model, test_loader, label):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    The model is put in eval mode; ``train_epoch`` switches it back with
    ``model.train()``.

    Returns:
        dict with the mean over all evaluated values of
        ``"iwae-64"`` (``IWAE_64(model, imgs)``), ``"logpx"``
        (``log_px(model, imgs)``) and ``"-kl"`` (``IWAE_64 - logpx``).
        NOTE(review): assumes both metric functions return one value per image.
    """
    model.eval()
    with torch.no_grad():
        iwae_scores = []
        logpx_scores = []
        negkl_scores = []

        for imgs, _ in tqdm(test_loader, desc="Test Epoch №%s" % label, leave=False):
            imgs = imgs.view(-1, 28*28).to(device)
            iwae_64 = IWAE_64(model, imgs)
            logpx = log_px(model, imgs)
            negkl = -(logpx - iwae_64)

            iwae_scores += iwae_64.tolist()
            logpx_scores += logpx.tolist()
            negkl_scores += negkl.tolist()

        return {
            "iwae-64": torch.tensor(iwae_scores).mean().item(),
            "logpx": torch.tensor(logpx_scores).mean().item(),
            "-kl"    : torch.tensor(negkl_scores).mean().item()
        }
        

In [8]:
def train_and_evaluate(
    model,
    train_loader,
    test_loader,
    test,
    run_no,
    snapshot_every=12,
):
    """Train ``model`` for ``test.no_epochs`` epochs, evaluating periodically.

    Evaluation runs every ``test.epochs_per_sample`` epochs and after the final
    epoch. Every ``snapshot_every`` epochs, the weights and the metrics gathered
    so far are written through ``mnist_train_data``. A falsy
    ``snapshot_every`` disables snapshots.

    Returns:
        pandas.DataFrame with columns ``epoch``, ``iwae-64``, ``logpx``, ``-kl``,
        one row per evaluation.
    """
    metric_names = ("iwae-64", "logpx", "-kl")
    test_metrics = {"epoch": [], **{name: [] for name in metric_names}}

    for epoch in range(1, test.no_epochs + 1):
        train_epoch(model, test.optims, train_loader, test.loss_fn, epoch)

        # evaluate on schedule, and always after the last epoch
        if epoch % test.epochs_per_sample == 0 or epoch == test.no_epochs:
            test_scores = test_epoch(model, test_loader, epoch)
            test_metrics["epoch"].append(epoch)
            for name in metric_names:
                test_metrics[name].append(test_scores[name])

        # snapshot model and metrics
        if snapshot_every and epoch % snapshot_every == 0:
            mnist_train_data.save_model(model, test.name, run_no, epoch)
            mnist_train_data.save_results(pd.DataFrame(test_metrics), test.name, run_no)

    return pd.DataFrame(test_metrics)

In [9]:
def eval_discard(loss_fn, M=1, K=1):
    """Adapt ``loss_fn`` to the ``(model, xs)`` signature used by ``train_epoch``.

    The returned callable evaluates ``model(xs, M, K)`` and feeds that result to
    ``loss_fn``. ``M`` and ``K`` are forwarded to the model unchanged.
    """
    def evaluate(model, xs):
        model_output = model(xs, M, K)
        return loss_fn(model_output)
    return evaluate

# Adam learning rate read by LossTest.initialize_model for every optimiser.
lr = 3e-4

class LossTest:
    """One training configuration: an objective plus its schedule.

    Attributes:
        name: tag used in artefact file names.
        loss_fn: callable ``(model, imgs) -> objective`` consumed by train_epoch.
        no_epochs: number of training epochs per run.
        epochs_per_sample: evaluate on the test set every this many epochs.
        no_runs: number of independent runs.
        piwae: if True, encoder and decoder parameters get separate Adam
            optimisers; otherwise one Adam covers all parameters.
        model, optims: (re)created by ``initialize_model``.
    """

    def __init__(self, name, loss_fn, no_epochs, epochs_per_sample, no_runs=1, piwae=False):
        self.name, self.loss_fn = name, loss_fn
        self.no_epochs, self.epochs_per_sample = no_epochs, epochs_per_sample
        self.no_runs = no_runs
        self.piwae = piwae
        self.initialize_model()

    def initialize_model(self):
        """Build a fresh ``VAE_MNIST`` and its Adam optimiser(s) at rate ``lr``."""
        self.model = VAE_MNIST()
        if self.piwae:
            param_sets = [self.model.encode_params(), self.model.decode_params()]
        else:
            param_sets = [self.model.parameters()]
        self.optims = [torch.optim.Adam(params, lr=lr) for params in param_sets]
            


# Experiments to run. Positional arguments follow LossTest.__init__:
#   (name, loss_fn, no_epochs, epochs_per_sample, no_runs[, piwae]).
# `name` is also the tag under which mnist_train_data saves checkpoints/results.
# piwae=True makes LossTest build separate encoder/decoder optimisers.
# Uncomment entries to train further objectives.
# NOTE(review): the active dreg-1-64 entry uses one optimiser for all
# parameters (piwae defaults to False), unlike the commented variant above it.
fig5_tests = [
    #LossTest("iwae", eval_discard(IWAE_loss, M=1, K=64), 10, 1, 1),
    #LossTest("ciwae-05", eval_discard(lambda res: CIWAE_loss(res, 0.5), M=1, K=64), 10, 1, 1),
    #LossTest("piwae-8-8", eval_discard(PIWAE_loss, M=8, K=8), 1000, 24, 1, piwae=True),
    #LossTest("miwae-8-8", eval_discard(IWAE_loss, M=8, K=8), 1000, 24, 1),
    #LossTest("vae", eval_discard(IWAE_loss, M=64, K=1), 10, 1, 1),
    #LossTest("dreg-1-64", eval_discard(DREG_loss, M=1, K=64), 1000, 24, 1, piwae=True)
    LossTest("dreg-1-64", eval_discard(DREG_loss, M=1, K=64), 1000, 24, 1)
]


# NOTE(review): the two calls below use an older train_and_evaluate signature
# (model, [optim], ..., no_epochs=...) and would fail if uncommented.
#iwae64_results = train_and_evaluate(model, [optim], train_loader, test_loader, iwae64, no_epochs=1)
#iwae05_results = train_and_evaluate(model, [optim], train_loader, test_loader, ciwae05, no_epochs=5)

    

In [None]:
for test in fig5_tests:
    print("Training ", test.name)
    for run_no in range(1, test.no_runs + 1):
        print("Run ", run_no)
        # Seed torch's RNGs per run: a run is reproducible on re-execution,
        # while different runs still start from different random states.
        torch.manual_seed(run_no)
        test.initialize_model()
        model = test.model.to(device)

        test_metrics = train_and_evaluate(
            model,
            train_loader,
            test_loader,
            test,
            run_no
        )

        # Persist final weights (epoch=None -> "final" tag) and the metric table.
        mnist_train_data.save_model(model, test.name, run_no, epoch = None)
        mnist_train_data.save_results(test_metrics, test.name, run_no)
        # move parameters off the GPU before the next run
        model.to('cpu')
        

Training  dreg-1-64
Run  1


HBox(children=(HTML(value='Training Epoch №1'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(225.2647)


HBox(children=(HTML(value='Training Epoch №2'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(162.8230)


HBox(children=(HTML(value='Training Epoch №3'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(146.6158)


HBox(children=(HTML(value='Training Epoch №4'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(137.6466)


HBox(children=(HTML(value='Training Epoch №5'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(131.5900)


HBox(children=(HTML(value='Training Epoch №6'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(127.1650)


HBox(children=(HTML(value='Training Epoch №7'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(123.6376)


HBox(children=(HTML(value='Training Epoch №8'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(120.8400)


HBox(children=(HTML(value='Training Epoch №9'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(118.3838)


HBox(children=(HTML(value='Training Epoch №10'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(116.0602)


HBox(children=(HTML(value='Training Epoch №11'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(114.0974)


HBox(children=(HTML(value='Training Epoch №12'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(112.4476)


HBox(children=(HTML(value='Training Epoch №13'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(111.0946)


HBox(children=(HTML(value='Training Epoch №14'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(109.8222)


HBox(children=(HTML(value='Training Epoch №15'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(108.7076)


HBox(children=(HTML(value='Training Epoch №16'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(107.6376)


HBox(children=(HTML(value='Training Epoch №17'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(106.6490)


HBox(children=(HTML(value='Training Epoch №18'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(105.7922)


HBox(children=(HTML(value='Training Epoch №19'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(105.0119)


HBox(children=(HTML(value='Training Epoch №20'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(104.3133)


HBox(children=(HTML(value='Training Epoch №21'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(103.6432)


HBox(children=(HTML(value='Training Epoch №22'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(103.1781)


HBox(children=(HTML(value='Training Epoch №23'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(102.7003)


HBox(children=(HTML(value='Training Epoch №24'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(102.1721)


HBox(children=(HTML(value='Test Epoch №24'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №25'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(101.7969)


HBox(children=(HTML(value='Training Epoch №26'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(101.3537)


HBox(children=(HTML(value='Training Epoch №27'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(101.0343)


HBox(children=(HTML(value='Training Epoch №28'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(100.7016)


HBox(children=(HTML(value='Training Epoch №29'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(100.3532)


HBox(children=(HTML(value='Training Epoch №30'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(100.0419)


HBox(children=(HTML(value='Training Epoch №31'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(99.7753)


HBox(children=(HTML(value='Training Epoch №32'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(99.4508)


HBox(children=(HTML(value='Training Epoch №33'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(99.2271)


HBox(children=(HTML(value='Training Epoch №34'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.9975)


HBox(children=(HTML(value='Training Epoch №35'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.7818)


HBox(children=(HTML(value='Training Epoch №36'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.6060)


HBox(children=(HTML(value='Training Epoch №37'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.3886)


HBox(children=(HTML(value='Training Epoch №38'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.2176)


HBox(children=(HTML(value='Training Epoch №39'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(98.0668)


HBox(children=(HTML(value='Training Epoch №40'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.8819)


HBox(children=(HTML(value='Training Epoch №41'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.7195)


HBox(children=(HTML(value='Training Epoch №42'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.5691)


HBox(children=(HTML(value='Training Epoch №43'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.3908)


HBox(children=(HTML(value='Training Epoch №44'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.2307)


HBox(children=(HTML(value='Training Epoch №45'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(97.1021)


HBox(children=(HTML(value='Training Epoch №46'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.9700)


HBox(children=(HTML(value='Training Epoch №47'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.8500)


HBox(children=(HTML(value='Training Epoch №48'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.6889)


HBox(children=(HTML(value='Test Epoch №48'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №49'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.6067)


HBox(children=(HTML(value='Training Epoch №50'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.4691)


HBox(children=(HTML(value='Training Epoch №51'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.4108)


HBox(children=(HTML(value='Training Epoch №52'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.2288)


HBox(children=(HTML(value='Training Epoch №53'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.1746)


HBox(children=(HTML(value='Training Epoch №54'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(96.0344)


HBox(children=(HTML(value='Training Epoch №55'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.9200)


HBox(children=(HTML(value='Training Epoch №56'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.8525)


HBox(children=(HTML(value='Training Epoch №57'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.7883)


HBox(children=(HTML(value='Training Epoch №58'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.7035)


HBox(children=(HTML(value='Training Epoch №59'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.5964)


HBox(children=(HTML(value='Training Epoch №60'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.5911)


HBox(children=(HTML(value='Training Epoch №61'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.4277)


HBox(children=(HTML(value='Training Epoch №62'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.4074)


HBox(children=(HTML(value='Training Epoch №63'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.3817)


HBox(children=(HTML(value='Training Epoch №64'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.1966)


HBox(children=(HTML(value='Training Epoch №65'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.1294)


HBox(children=(HTML(value='Training Epoch №66'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.0512)


HBox(children=(HTML(value='Training Epoch №67'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(95.0031)


HBox(children=(HTML(value='Training Epoch №68'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.9617)


HBox(children=(HTML(value='Training Epoch №69'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.8802)


HBox(children=(HTML(value='Training Epoch №70'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.8854)


HBox(children=(HTML(value='Training Epoch №71'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.7584)


HBox(children=(HTML(value='Training Epoch №72'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.7345)


HBox(children=(HTML(value='Test Epoch №72'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №73'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.6692)


HBox(children=(HTML(value='Training Epoch №74'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.5913)


HBox(children=(HTML(value='Training Epoch №75'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.5838)


HBox(children=(HTML(value='Training Epoch №76'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.4803)


HBox(children=(HTML(value='Training Epoch №77'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.4544)


HBox(children=(HTML(value='Training Epoch №78'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.4064)


HBox(children=(HTML(value='Training Epoch №79'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.3598)


HBox(children=(HTML(value='Training Epoch №80'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.3298)


HBox(children=(HTML(value='Training Epoch №81'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.3237)


HBox(children=(HTML(value='Training Epoch №82'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.2069)


HBox(children=(HTML(value='Training Epoch №83'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.1809)


HBox(children=(HTML(value='Training Epoch №84'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.1128)


HBox(children=(HTML(value='Training Epoch №85'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.1132)


HBox(children=(HTML(value='Training Epoch №86'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.0492)


HBox(children=(HTML(value='Training Epoch №87'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(94.0010)


HBox(children=(HTML(value='Training Epoch №88'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.9971)


HBox(children=(HTML(value='Training Epoch №89'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.8978)


HBox(children=(HTML(value='Training Epoch №90'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.8940)


HBox(children=(HTML(value='Training Epoch №91'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.7991)


HBox(children=(HTML(value='Training Epoch №92'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.7722)


HBox(children=(HTML(value='Training Epoch №93'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.7541)


HBox(children=(HTML(value='Training Epoch №94'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.7207)


HBox(children=(HTML(value='Training Epoch №95'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.6874)


HBox(children=(HTML(value='Training Epoch №96'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.6372)


HBox(children=(HTML(value='Test Epoch №96'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №97'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.6343)


HBox(children=(HTML(value='Training Epoch №98'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.5675)


HBox(children=(HTML(value='Training Epoch №99'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.5441)


HBox(children=(HTML(value='Training Epoch №100'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.4830)


HBox(children=(HTML(value='Training Epoch №101'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.4746)


HBox(children=(HTML(value='Training Epoch №102'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.4031)


HBox(children=(HTML(value='Training Epoch №103'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.4503)


HBox(children=(HTML(value='Training Epoch №104'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.4114)


HBox(children=(HTML(value='Training Epoch №105'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.3279)


HBox(children=(HTML(value='Training Epoch №106'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.3106)


HBox(children=(HTML(value='Training Epoch №107'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.3128)


HBox(children=(HTML(value='Training Epoch №108'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.2500)


HBox(children=(HTML(value='Training Epoch №109'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.2392)


HBox(children=(HTML(value='Training Epoch №110'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.1924)


HBox(children=(HTML(value='Training Epoch №111'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.1332)


HBox(children=(HTML(value='Training Epoch №112'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.1098)


HBox(children=(HTML(value='Training Epoch №113'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.0962)


HBox(children=(HTML(value='Training Epoch №114'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.0873)


HBox(children=(HTML(value='Training Epoch №115'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.0472)


HBox(children=(HTML(value='Training Epoch №116'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.0337)


HBox(children=(HTML(value='Training Epoch №117'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(93.0328)


HBox(children=(HTML(value='Training Epoch №118'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.9495)


HBox(children=(HTML(value='Training Epoch №119'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.9797)


HBox(children=(HTML(value='Training Epoch №120'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.9155)


HBox(children=(HTML(value='Test Epoch №120'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №121'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8792)


HBox(children=(HTML(value='Training Epoch №122'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8891)


HBox(children=(HTML(value='Training Epoch №123'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8999)


HBox(children=(HTML(value='Training Epoch №124'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8683)


HBox(children=(HTML(value='Training Epoch №125'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8155)


HBox(children=(HTML(value='Training Epoch №126'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.8546)


HBox(children=(HTML(value='Training Epoch №127'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.7541)


HBox(children=(HTML(value='Training Epoch №128'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.7148)


HBox(children=(HTML(value='Training Epoch №129'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.7258)


HBox(children=(HTML(value='Training Epoch №130'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.6878)


HBox(children=(HTML(value='Training Epoch №131'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.7114)


HBox(children=(HTML(value='Training Epoch №132'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.6292)


HBox(children=(HTML(value='Training Epoch №133'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.5968)


HBox(children=(HTML(value='Training Epoch №134'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.6126)


HBox(children=(HTML(value='Training Epoch №135'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.5679)


HBox(children=(HTML(value='Training Epoch №136'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.6099)


HBox(children=(HTML(value='Training Epoch №137'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.5487)


HBox(children=(HTML(value='Training Epoch №138'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.5013)


HBox(children=(HTML(value='Training Epoch №139'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4791)


HBox(children=(HTML(value='Training Epoch №140'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.5388)


HBox(children=(HTML(value='Training Epoch №141'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4298)


HBox(children=(HTML(value='Training Epoch №142'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4355)


HBox(children=(HTML(value='Training Epoch №143'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4943)


HBox(children=(HTML(value='Training Epoch №144'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4253)


HBox(children=(HTML(value='Test Epoch №144'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №145'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.3661)


HBox(children=(HTML(value='Training Epoch №146'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.4170)


HBox(children=(HTML(value='Training Epoch №147'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.3871)


HBox(children=(HTML(value='Training Epoch №148'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.3203)


HBox(children=(HTML(value='Training Epoch №149'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2948)


HBox(children=(HTML(value='Training Epoch №150'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.3064)


HBox(children=(HTML(value='Training Epoch №151'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2733)


HBox(children=(HTML(value='Training Epoch №152'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2280)


HBox(children=(HTML(value='Training Epoch №153'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2867)


HBox(children=(HTML(value='Training Epoch №154'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2122)


HBox(children=(HTML(value='Training Epoch №155'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2371)


HBox(children=(HTML(value='Training Epoch №156'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2067)


HBox(children=(HTML(value='Training Epoch №157'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.2003)


HBox(children=(HTML(value='Training Epoch №158'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.1743)


HBox(children=(HTML(value='Training Epoch №159'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.1583)


HBox(children=(HTML(value='Training Epoch №160'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0905)


HBox(children=(HTML(value='Training Epoch №161'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.1227)


HBox(children=(HTML(value='Training Epoch №162'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0993)


HBox(children=(HTML(value='Training Epoch №163'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0588)


HBox(children=(HTML(value='Training Epoch №164'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.1041)


HBox(children=(HTML(value='Training Epoch №165'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0695)


HBox(children=(HTML(value='Training Epoch №166'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0155)


HBox(children=(HTML(value='Training Epoch №167'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9773)


HBox(children=(HTML(value='Training Epoch №168'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0419)


HBox(children=(HTML(value='Test Epoch №168'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №169'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(92.0506)


HBox(children=(HTML(value='Training Epoch №170'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9946)


HBox(children=(HTML(value='Training Epoch №171'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9788)


HBox(children=(HTML(value='Training Epoch №172'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9958)


HBox(children=(HTML(value='Training Epoch №173'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9161)


HBox(children=(HTML(value='Training Epoch №174'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9848)


HBox(children=(HTML(value='Training Epoch №175'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9313)


HBox(children=(HTML(value='Training Epoch №176'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8746)


HBox(children=(HTML(value='Training Epoch №177'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.9232)


HBox(children=(HTML(value='Training Epoch №178'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8450)


HBox(children=(HTML(value='Training Epoch №179'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8915)


HBox(children=(HTML(value='Training Epoch №180'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8538)


HBox(children=(HTML(value='Training Epoch №181'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8133)


HBox(children=(HTML(value='Training Epoch №182'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7913)


HBox(children=(HTML(value='Training Epoch №183'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8254)


HBox(children=(HTML(value='Training Epoch №184'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7761)


HBox(children=(HTML(value='Training Epoch №185'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8106)


HBox(children=(HTML(value='Training Epoch №186'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.8366)


HBox(children=(HTML(value='Training Epoch №187'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7921)


HBox(children=(HTML(value='Training Epoch №188'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7244)


HBox(children=(HTML(value='Training Epoch №189'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7541)


HBox(children=(HTML(value='Training Epoch №190'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7531)


HBox(children=(HTML(value='Training Epoch №191'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7437)


HBox(children=(HTML(value='Training Epoch №192'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6987)


HBox(children=(HTML(value='Test Epoch №192'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №193'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.7151)


HBox(children=(HTML(value='Training Epoch №194'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6431)


HBox(children=(HTML(value='Training Epoch №195'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6729)


HBox(children=(HTML(value='Training Epoch №196'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6282)


HBox(children=(HTML(value='Training Epoch №197'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6760)


HBox(children=(HTML(value='Training Epoch №198'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6226)


HBox(children=(HTML(value='Training Epoch №199'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6243)


HBox(children=(HTML(value='Training Epoch №200'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6058)


HBox(children=(HTML(value='Training Epoch №201'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6077)


HBox(children=(HTML(value='Training Epoch №202'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.6123)


HBox(children=(HTML(value='Training Epoch №203'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5803)


HBox(children=(HTML(value='Training Epoch №204'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5824)


HBox(children=(HTML(value='Training Epoch №205'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5626)


HBox(children=(HTML(value='Training Epoch №206'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5621)


HBox(children=(HTML(value='Training Epoch №207'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5885)


HBox(children=(HTML(value='Training Epoch №208'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5609)


HBox(children=(HTML(value='Training Epoch №209'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5055)


HBox(children=(HTML(value='Training Epoch №210'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5088)


HBox(children=(HTML(value='Training Epoch №211'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5106)


HBox(children=(HTML(value='Training Epoch №212'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5088)


HBox(children=(HTML(value='Training Epoch №213'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4777)


HBox(children=(HTML(value='Training Epoch №214'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.5008)


HBox(children=(HTML(value='Training Epoch №215'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4970)


HBox(children=(HTML(value='Training Epoch №216'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4789)


HBox(children=(HTML(value='Test Epoch №216'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №217'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4418)


HBox(children=(HTML(value='Training Epoch №218'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4430)


HBox(children=(HTML(value='Training Epoch №219'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4159)


HBox(children=(HTML(value='Training Epoch №220'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4050)


HBox(children=(HTML(value='Training Epoch №221'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3923)


HBox(children=(HTML(value='Training Epoch №222'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3942)


HBox(children=(HTML(value='Training Epoch №223'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3961)


HBox(children=(HTML(value='Training Epoch №224'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.4063)


HBox(children=(HTML(value='Training Epoch №225'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3500)


HBox(children=(HTML(value='Training Epoch №226'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3524)


HBox(children=(HTML(value='Training Epoch №227'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3176)


HBox(children=(HTML(value='Training Epoch №228'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3428)


HBox(children=(HTML(value='Training Epoch №229'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3497)


HBox(children=(HTML(value='Training Epoch №230'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3752)


HBox(children=(HTML(value='Training Epoch №231'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3458)


HBox(children=(HTML(value='Training Epoch №232'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.3242)


HBox(children=(HTML(value='Training Epoch №233'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2621)


HBox(children=(HTML(value='Training Epoch №234'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2825)


HBox(children=(HTML(value='Training Epoch №235'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2700)


HBox(children=(HTML(value='Training Epoch №236'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2847)


HBox(children=(HTML(value='Training Epoch №237'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2487)


HBox(children=(HTML(value='Training Epoch №238'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2631)


HBox(children=(HTML(value='Training Epoch №239'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2609)


HBox(children=(HTML(value='Training Epoch №240'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2253)


HBox(children=(HTML(value='Test Epoch №240'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №241'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2325)


HBox(children=(HTML(value='Training Epoch №242'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2139)


HBox(children=(HTML(value='Training Epoch №243'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2288)


HBox(children=(HTML(value='Training Epoch №244'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2263)


HBox(children=(HTML(value='Training Epoch №245'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2102)


HBox(children=(HTML(value='Training Epoch №246'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1870)


HBox(children=(HTML(value='Training Epoch №247'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1644)


HBox(children=(HTML(value='Training Epoch №248'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.2077)


HBox(children=(HTML(value='Training Epoch №249'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1442)


HBox(children=(HTML(value='Training Epoch №250'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1698)


HBox(children=(HTML(value='Training Epoch №251'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1675)


HBox(children=(HTML(value='Training Epoch №252'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1462)


HBox(children=(HTML(value='Training Epoch №253'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1365)


HBox(children=(HTML(value='Training Epoch №254'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0853)


HBox(children=(HTML(value='Training Epoch №255'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1025)


HBox(children=(HTML(value='Training Epoch №256'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0740)


HBox(children=(HTML(value='Training Epoch №257'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1309)


HBox(children=(HTML(value='Training Epoch №258'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0607)


HBox(children=(HTML(value='Training Epoch №259'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1515)


HBox(children=(HTML(value='Training Epoch №260'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1036)


HBox(children=(HTML(value='Training Epoch №261'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.1116)


HBox(children=(HTML(value='Training Epoch №262'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0435)


HBox(children=(HTML(value='Training Epoch №263'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0857)


HBox(children=(HTML(value='Training Epoch №264'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0681)


HBox(children=(HTML(value='Test Epoch №264'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №265'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0534)


HBox(children=(HTML(value='Training Epoch №266'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0775)


HBox(children=(HTML(value='Training Epoch №267'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0099)


HBox(children=(HTML(value='Training Epoch №268'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0419)


HBox(children=(HTML(value='Training Epoch №269'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0254)


HBox(children=(HTML(value='Training Epoch №270'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0694)


HBox(children=(HTML(value='Training Epoch №271'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9901)


HBox(children=(HTML(value='Training Epoch №272'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0054)


HBox(children=(HTML(value='Training Epoch №273'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0542)


HBox(children=(HTML(value='Training Epoch №274'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0433)


HBox(children=(HTML(value='Training Epoch №275'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9317)


HBox(children=(HTML(value='Training Epoch №276'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0067)


HBox(children=(HTML(value='Training Epoch №277'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9936)


HBox(children=(HTML(value='Training Epoch №278'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(91.0161)


HBox(children=(HTML(value='Training Epoch №279'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9741)


HBox(children=(HTML(value='Training Epoch №280'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9485)


HBox(children=(HTML(value='Training Epoch №281'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9478)


HBox(children=(HTML(value='Training Epoch №282'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9701)


HBox(children=(HTML(value='Training Epoch №283'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9686)


HBox(children=(HTML(value='Training Epoch №284'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9778)


HBox(children=(HTML(value='Training Epoch №285'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9150)


HBox(children=(HTML(value='Training Epoch №286'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9413)


HBox(children=(HTML(value='Training Epoch №287'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9273)


HBox(children=(HTML(value='Training Epoch №288'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9324)


HBox(children=(HTML(value='Test Epoch №288'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №289'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.9435)


HBox(children=(HTML(value='Training Epoch №290'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8776)


HBox(children=(HTML(value='Training Epoch №291'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8454)


HBox(children=(HTML(value='Training Epoch №292'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8881)


HBox(children=(HTML(value='Training Epoch №293'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8798)


HBox(children=(HTML(value='Training Epoch №294'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8291)


HBox(children=(HTML(value='Training Epoch №295'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8217)


HBox(children=(HTML(value='Training Epoch №296'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8633)


HBox(children=(HTML(value='Training Epoch №297'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8016)


HBox(children=(HTML(value='Training Epoch №298'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8008)


HBox(children=(HTML(value='Training Epoch №299'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8136)


HBox(children=(HTML(value='Training Epoch №300'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8337)


HBox(children=(HTML(value='Training Epoch №301'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8362)


HBox(children=(HTML(value='Training Epoch №302'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8554)


HBox(children=(HTML(value='Training Epoch №303'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8063)


HBox(children=(HTML(value='Training Epoch №304'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.8134)


HBox(children=(HTML(value='Training Epoch №305'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7781)


HBox(children=(HTML(value='Training Epoch №306'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7914)


HBox(children=(HTML(value='Training Epoch №307'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7979)


HBox(children=(HTML(value='Training Epoch №308'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7984)


HBox(children=(HTML(value='Training Epoch №309'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7626)


HBox(children=(HTML(value='Training Epoch №310'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7772)


HBox(children=(HTML(value='Training Epoch №311'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7086)


HBox(children=(HTML(value='Training Epoch №312'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7738)


HBox(children=(HTML(value='Test Epoch №312'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №313'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7433)


HBox(children=(HTML(value='Training Epoch №314'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7840)


HBox(children=(HTML(value='Training Epoch №315'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7556)


HBox(children=(HTML(value='Training Epoch №316'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7132)


HBox(children=(HTML(value='Training Epoch №317'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7114)


HBox(children=(HTML(value='Training Epoch №318'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6785)


HBox(children=(HTML(value='Training Epoch №319'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7225)


HBox(children=(HTML(value='Training Epoch №320'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7097)


HBox(children=(HTML(value='Training Epoch №321'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7385)


HBox(children=(HTML(value='Training Epoch №322'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7431)


HBox(children=(HTML(value='Training Epoch №323'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7078)


HBox(children=(HTML(value='Training Epoch №324'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7343)


HBox(children=(HTML(value='Training Epoch №325'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6505)


HBox(children=(HTML(value='Training Epoch №326'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6762)


HBox(children=(HTML(value='Training Epoch №327'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6958)


HBox(children=(HTML(value='Training Epoch №328'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.7009)


HBox(children=(HTML(value='Training Epoch №329'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6644)


HBox(children=(HTML(value='Training Epoch №330'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6252)


HBox(children=(HTML(value='Training Epoch №331'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6660)


HBox(children=(HTML(value='Training Epoch №332'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6450)


HBox(children=(HTML(value='Training Epoch №333'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6270)


HBox(children=(HTML(value='Training Epoch №334'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6438)


HBox(children=(HTML(value='Training Epoch №335'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6107)


HBox(children=(HTML(value='Training Epoch №336'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6591)


HBox(children=(HTML(value='Test Epoch №336'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №337'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5978)


HBox(children=(HTML(value='Training Epoch №338'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6072)


HBox(children=(HTML(value='Training Epoch №339'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5849)


HBox(children=(HTML(value='Training Epoch №340'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6117)


HBox(children=(HTML(value='Training Epoch №341'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5921)


HBox(children=(HTML(value='Training Epoch №342'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6129)


HBox(children=(HTML(value='Training Epoch №343'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6231)


HBox(children=(HTML(value='Training Epoch №344'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5312)


HBox(children=(HTML(value='Training Epoch №345'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6209)


HBox(children=(HTML(value='Training Epoch №346'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5157)


HBox(children=(HTML(value='Training Epoch №347'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5677)


HBox(children=(HTML(value='Training Epoch №348'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.6168)


HBox(children=(HTML(value='Training Epoch №349'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5936)


HBox(children=(HTML(value='Training Epoch №350'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5861)


HBox(children=(HTML(value='Training Epoch №351'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5235)


HBox(children=(HTML(value='Training Epoch №352'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5187)


HBox(children=(HTML(value='Training Epoch №353'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5611)


HBox(children=(HTML(value='Training Epoch №354'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5714)


HBox(children=(HTML(value='Training Epoch №355'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5479)


HBox(children=(HTML(value='Training Epoch №356'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5308)


HBox(children=(HTML(value='Training Epoch №357'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5403)


HBox(children=(HTML(value='Training Epoch №358'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5924)


HBox(children=(HTML(value='Training Epoch №359'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5337)


HBox(children=(HTML(value='Training Epoch №360'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5379)


HBox(children=(HTML(value='Test Epoch №360'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №361'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5337)


HBox(children=(HTML(value='Training Epoch №362'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5099)


HBox(children=(HTML(value='Training Epoch №363'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5075)


HBox(children=(HTML(value='Training Epoch №364'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4862)


HBox(children=(HTML(value='Training Epoch №365'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5075)


HBox(children=(HTML(value='Training Epoch №366'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4618)


HBox(children=(HTML(value='Training Epoch №367'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4272)


HBox(children=(HTML(value='Training Epoch №368'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4686)


HBox(children=(HTML(value='Training Epoch №369'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4508)


HBox(children=(HTML(value='Training Epoch №370'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5241)


HBox(children=(HTML(value='Training Epoch №371'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5229)


HBox(children=(HTML(value='Training Epoch №372'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.5390)


HBox(children=(HTML(value='Training Epoch №373'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4662)


HBox(children=(HTML(value='Training Epoch №374'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4478)


HBox(children=(HTML(value='Training Epoch №375'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4728)


HBox(children=(HTML(value='Training Epoch №376'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4532)


HBox(children=(HTML(value='Training Epoch №377'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4580)


HBox(children=(HTML(value='Training Epoch №378'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4630)


HBox(children=(HTML(value='Training Epoch №379'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4211)


HBox(children=(HTML(value='Training Epoch №380'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4510)


HBox(children=(HTML(value='Training Epoch №381'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4486)


HBox(children=(HTML(value='Training Epoch №382'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4728)


HBox(children=(HTML(value='Training Epoch №383'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4615)


HBox(children=(HTML(value='Training Epoch №384'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4503)


HBox(children=(HTML(value='Test Epoch №384'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №385'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4717)


HBox(children=(HTML(value='Training Epoch №386'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4323)


HBox(children=(HTML(value='Training Epoch №387'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3713)


HBox(children=(HTML(value='Training Epoch №388'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3570)


HBox(children=(HTML(value='Training Epoch №389'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4252)


HBox(children=(HTML(value='Training Epoch №390'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4201)


HBox(children=(HTML(value='Training Epoch №391'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3500)


HBox(children=(HTML(value='Training Epoch №392'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3844)


HBox(children=(HTML(value='Training Epoch №393'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3538)


HBox(children=(HTML(value='Training Epoch №394'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.4126)


HBox(children=(HTML(value='Training Epoch №395'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3794)


HBox(children=(HTML(value='Training Epoch №396'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3660)


HBox(children=(HTML(value='Training Epoch №397'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3913)


HBox(children=(HTML(value='Training Epoch №398'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3882)


HBox(children=(HTML(value='Training Epoch №399'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3890)


HBox(children=(HTML(value='Training Epoch №400'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3688)


HBox(children=(HTML(value='Training Epoch №401'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3644)


HBox(children=(HTML(value='Training Epoch №402'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3265)


HBox(children=(HTML(value='Training Epoch №403'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3522)


HBox(children=(HTML(value='Training Epoch №404'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3323)


HBox(children=(HTML(value='Training Epoch №405'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3253)


HBox(children=(HTML(value='Training Epoch №406'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3234)


HBox(children=(HTML(value='Training Epoch №407'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3451)


HBox(children=(HTML(value='Training Epoch №408'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3222)


HBox(children=(HTML(value='Test Epoch №408'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №409'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3066)


HBox(children=(HTML(value='Training Epoch №410'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3348)


HBox(children=(HTML(value='Training Epoch №411'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3324)


HBox(children=(HTML(value='Training Epoch №412'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2702)


HBox(children=(HTML(value='Training Epoch №413'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2661)


HBox(children=(HTML(value='Training Epoch №414'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3125)


HBox(children=(HTML(value='Training Epoch №415'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2579)


HBox(children=(HTML(value='Training Epoch №416'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3091)


HBox(children=(HTML(value='Training Epoch №417'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2737)


HBox(children=(HTML(value='Training Epoch №418'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3183)


HBox(children=(HTML(value='Training Epoch №419'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3186)


HBox(children=(HTML(value='Training Epoch №420'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2811)


HBox(children=(HTML(value='Training Epoch №421'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2923)


HBox(children=(HTML(value='Training Epoch №422'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3093)


HBox(children=(HTML(value='Training Epoch №423'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2814)


HBox(children=(HTML(value='Training Epoch №424'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2662)


HBox(children=(HTML(value='Training Epoch №425'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2672)


HBox(children=(HTML(value='Training Epoch №426'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2867)


HBox(children=(HTML(value='Training Epoch №427'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2690)


HBox(children=(HTML(value='Training Epoch №428'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3058)


HBox(children=(HTML(value='Training Epoch №429'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3086)


HBox(children=(HTML(value='Training Epoch №430'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2122)


HBox(children=(HTML(value='Training Epoch №431'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.3213)


HBox(children=(HTML(value='Training Epoch №432'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2487)


HBox(children=(HTML(value='Test Epoch №432'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №433'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2575)


HBox(children=(HTML(value='Training Epoch №434'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2787)


HBox(children=(HTML(value='Training Epoch №435'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2293)


HBox(children=(HTML(value='Training Epoch №436'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2677)


HBox(children=(HTML(value='Training Epoch №437'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2709)


HBox(children=(HTML(value='Training Epoch №438'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2337)


HBox(children=(HTML(value='Training Epoch №439'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2057)


HBox(children=(HTML(value='Training Epoch №440'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2020)


HBox(children=(HTML(value='Training Epoch №441'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2726)


HBox(children=(HTML(value='Training Epoch №442'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1917)


HBox(children=(HTML(value='Training Epoch №443'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2021)


HBox(children=(HTML(value='Training Epoch №444'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1694)


HBox(children=(HTML(value='Training Epoch №445'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.2523)


HBox(children=(HTML(value='Training Epoch №446'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1865)


HBox(children=(HTML(value='Training Epoch №447'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1485)


HBox(children=(HTML(value='Training Epoch №448'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1824)


HBox(children=(HTML(value='Training Epoch №449'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1679)


HBox(children=(HTML(value='Training Epoch №450'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1542)


HBox(children=(HTML(value='Training Epoch №451'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1708)


HBox(children=(HTML(value='Training Epoch №452'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1832)


HBox(children=(HTML(value='Training Epoch №453'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1433)


HBox(children=(HTML(value='Training Epoch №454'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1591)


HBox(children=(HTML(value='Training Epoch №455'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1827)


HBox(children=(HTML(value='Training Epoch №456'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1507)


HBox(children=(HTML(value='Test Epoch №456'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №457'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1771)


HBox(children=(HTML(value='Training Epoch №458'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1905)


HBox(children=(HTML(value='Training Epoch №459'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1648)


HBox(children=(HTML(value='Training Epoch №460'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1745)


HBox(children=(HTML(value='Training Epoch №461'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1382)


HBox(children=(HTML(value='Training Epoch №462'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0494)


HBox(children=(HTML(value='Training Epoch №463'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1159)


HBox(children=(HTML(value='Training Epoch №464'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1326)


HBox(children=(HTML(value='Training Epoch №465'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1434)


HBox(children=(HTML(value='Training Epoch №466'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1332)


HBox(children=(HTML(value='Training Epoch №467'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1112)


HBox(children=(HTML(value='Training Epoch №468'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0986)


HBox(children=(HTML(value='Training Epoch №469'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0842)


HBox(children=(HTML(value='Training Epoch №470'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1034)


HBox(children=(HTML(value='Training Epoch №471'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0987)


HBox(children=(HTML(value='Training Epoch №472'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0918)


HBox(children=(HTML(value='Training Epoch №473'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1061)


HBox(children=(HTML(value='Training Epoch №474'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0918)


HBox(children=(HTML(value='Training Epoch №475'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1317)


HBox(children=(HTML(value='Training Epoch №476'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1156)


HBox(children=(HTML(value='Training Epoch №477'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1227)


HBox(children=(HTML(value='Training Epoch №478'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1499)


HBox(children=(HTML(value='Training Epoch №479'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0959)


HBox(children=(HTML(value='Training Epoch №480'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0912)


HBox(children=(HTML(value='Test Epoch №480'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №481'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0937)


HBox(children=(HTML(value='Training Epoch №482'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0756)


HBox(children=(HTML(value='Training Epoch №483'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1001)


HBox(children=(HTML(value='Training Epoch №484'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0856)


HBox(children=(HTML(value='Training Epoch №485'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0906)


HBox(children=(HTML(value='Training Epoch №486'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0866)


HBox(children=(HTML(value='Training Epoch №487'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0519)


HBox(children=(HTML(value='Training Epoch №488'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0073)


HBox(children=(HTML(value='Training Epoch №489'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.1070)


HBox(children=(HTML(value='Training Epoch №490'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0542)


HBox(children=(HTML(value='Training Epoch №491'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0675)


HBox(children=(HTML(value='Training Epoch №492'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0816)


HBox(children=(HTML(value='Training Epoch №493'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0495)


HBox(children=(HTML(value='Training Epoch №494'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0213)


HBox(children=(HTML(value='Training Epoch №495'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0270)


HBox(children=(HTML(value='Training Epoch №496'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0373)


HBox(children=(HTML(value='Training Epoch №497'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0385)


HBox(children=(HTML(value='Training Epoch №498'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0631)


HBox(children=(HTML(value='Training Epoch №499'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0195)


HBox(children=(HTML(value='Training Epoch №500'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0138)


HBox(children=(HTML(value='Training Epoch №501'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0528)


HBox(children=(HTML(value='Training Epoch №502'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0124)


HBox(children=(HTML(value='Training Epoch №503'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0216)


HBox(children=(HTML(value='Training Epoch №504'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0168)


HBox(children=(HTML(value='Test Epoch №504'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №505'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9811)


HBox(children=(HTML(value='Training Epoch №506'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9749)


HBox(children=(HTML(value='Training Epoch №507'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9650)


HBox(children=(HTML(value='Training Epoch №508'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9763)


HBox(children=(HTML(value='Training Epoch №509'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0115)


HBox(children=(HTML(value='Training Epoch №510'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9817)


HBox(children=(HTML(value='Training Epoch №511'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0893)


HBox(children=(HTML(value='Training Epoch №512'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9833)


HBox(children=(HTML(value='Training Epoch №513'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0561)


HBox(children=(HTML(value='Training Epoch №514'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9927)


HBox(children=(HTML(value='Training Epoch №515'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9710)


HBox(children=(HTML(value='Training Epoch №516'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9750)


HBox(children=(HTML(value='Training Epoch №517'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0060)


HBox(children=(HTML(value='Training Epoch №518'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9855)


HBox(children=(HTML(value='Training Epoch №519'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9878)


HBox(children=(HTML(value='Training Epoch №520'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0183)


HBox(children=(HTML(value='Training Epoch №521'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9796)


HBox(children=(HTML(value='Training Epoch №522'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9941)


HBox(children=(HTML(value='Training Epoch №523'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9827)


HBox(children=(HTML(value='Training Epoch №524'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(90.0010)


HBox(children=(HTML(value='Training Epoch №525'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9898)


HBox(children=(HTML(value='Training Epoch №526'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9382)


HBox(children=(HTML(value='Training Epoch №527'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9537)


HBox(children=(HTML(value='Training Epoch №528'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9312)


HBox(children=(HTML(value='Test Epoch №528'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №529'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9476)


HBox(children=(HTML(value='Training Epoch №530'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9585)


HBox(children=(HTML(value='Training Epoch №531'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9666)


HBox(children=(HTML(value='Training Epoch №532'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9529)


HBox(children=(HTML(value='Training Epoch №533'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9695)


HBox(children=(HTML(value='Training Epoch №534'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9374)


HBox(children=(HTML(value='Training Epoch №535'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9301)


HBox(children=(HTML(value='Training Epoch №536'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9291)


HBox(children=(HTML(value='Training Epoch №537'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9305)


HBox(children=(HTML(value='Training Epoch №538'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9055)


HBox(children=(HTML(value='Training Epoch №539'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9665)


HBox(children=(HTML(value='Training Epoch №540'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9473)


HBox(children=(HTML(value='Training Epoch №541'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9448)


HBox(children=(HTML(value='Training Epoch №542'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9029)


HBox(children=(HTML(value='Training Epoch №543'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9061)


HBox(children=(HTML(value='Training Epoch №544'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9284)


HBox(children=(HTML(value='Training Epoch №545'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9110)


HBox(children=(HTML(value='Training Epoch №546'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9072)


HBox(children=(HTML(value='Training Epoch №547'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8988)


HBox(children=(HTML(value='Training Epoch №548'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8408)


HBox(children=(HTML(value='Training Epoch №549'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9306)


HBox(children=(HTML(value='Training Epoch №550'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9277)


HBox(children=(HTML(value='Training Epoch №551'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9091)


HBox(children=(HTML(value='Training Epoch №552'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8883)


HBox(children=(HTML(value='Test Epoch №552'), FloatProgress(value=0.0, max=209.0), HTML(value='')))

HBox(children=(HTML(value='Training Epoch №553'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8914)


HBox(children=(HTML(value='Training Epoch №554'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9026)


HBox(children=(HTML(value='Training Epoch №555'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8906)


HBox(children=(HTML(value='Training Epoch №556'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9033)


HBox(children=(HTML(value='Training Epoch №557'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8715)


HBox(children=(HTML(value='Training Epoch №558'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9129)


HBox(children=(HTML(value='Training Epoch №559'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8966)


HBox(children=(HTML(value='Training Epoch №560'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8802)


HBox(children=(HTML(value='Training Epoch №561'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9161)


HBox(children=(HTML(value='Training Epoch №562'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8835)


HBox(children=(HTML(value='Training Epoch №563'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8794)


HBox(children=(HTML(value='Training Epoch №564'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8670)


HBox(children=(HTML(value='Training Epoch №565'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8570)


HBox(children=(HTML(value='Training Epoch №566'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8719)


HBox(children=(HTML(value='Training Epoch №567'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9185)


HBox(children=(HTML(value='Training Epoch №568'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8128)


HBox(children=(HTML(value='Training Epoch №569'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8399)


HBox(children=(HTML(value='Training Epoch №570'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8497)


HBox(children=(HTML(value='Training Epoch №571'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8531)


HBox(children=(HTML(value='Training Epoch №572'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8443)


HBox(children=(HTML(value='Training Epoch №573'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.9019)


HBox(children=(HTML(value='Training Epoch №574'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(89.8459)


HBox(children=(HTML(value='Training Epoch №575'), FloatProgress(value=0.0, max=469.0), HTML(value='')))

In [None]:
def compute_ess(model, T=100, loader=None):
    """Collect per-image effective sample sizes (ESS) for `model` over a data loader.

    Args:
        model: model passed straight through to `sample_ess`.
        T: forwarded to `sample_ess` as `T=` (presumably the number of
            importance samples drawn per image -- confirm in comparison.metric).
        loader: iterable of `(images, labels)` batches. Defaults to the
            notebook-global `test_loader`, so existing calls keep working.

    Returns:
        A flat list of Python floats: one `.item()` per element obtained by
        iterating over the `sample_ess` result of each batch.
    """
    if loader is None:
        loader = test_loader  # previous behaviour: always read the global loader
    effective_sample_sizes = []
    # Pure evaluation: no autograd graph is needed for any batch.
    with torch.no_grad():
        for imgs, _ in tqdm(loader, desc="Computing ESS", leave=False):
            # Flatten each 28x28 image to a 784-vector on the target device.
            imgs = imgs.view(-1, 28 * 28).to(device)
            esss = sample_ess(model, imgs, T=T)
            effective_sample_sizes.extend(ess.item() for ess in esss)
    return effective_sample_sizes

# T used for every model in fig5_tests (larger than compute_ess's default of 100).
ESS_NUM_SAMPLES = 1000

# One list of per-image ESS values per model, in fig5_tests order.
ess_samples = [
    compute_ess(fig5_test.model.to(device), ESS_NUM_SAMPLES)
    for fig5_test in fig5_tests
]
    

In [None]:
import matplotlib.pyplot as plt  # kept so this cell also runs on its own

# One violin per entry of ess_samples (one per model in fig5_tests); the
# default violin positions are 1..len(ess_samples).
fig, ax = plt.subplots()
ax.violinplot(ess_samples, showmedians=True, points=100)

# Zoom in on the low end of the ESS range; anything above 0.04 is off-axis.
ax.set_ylim(0, 40e-3)
ax.set_xlabel('Model (1-based position in fig5_tests)')
ax.set_ylabel('Effective sample size')
ax.set_title('Per-image ESS distribution for each model')
plt.show()

In [None]:
# Show original test images side by side with the model's reconstructions.
import matplotlib.pyplot as plt  # kept so this cell also runs on its own

N_RECONSTRUCTIONS = 10  # number of (original, reconstruction) rows to draw

# Only the first test batch is needed.
imgs, _ = next(iter(test_loader))
n_shown = min(N_RECONSTRUCTIONS, imgs.shape[0])

plt.figure(figsize=(4, 26))
with torch.no_grad():  # inference only; no autograd graph needed
    for i in range(n_shown):
        img = imgs[i].view(28 * 28).to(device)
        recon = model.reconstruct(img)

        # Left column: the original image.
        ax = plt.subplot(N_RECONSTRUCTIONS, 2, 2 * i + 1)
        ax.imshow(img.view(28, 28).detach().cpu())
        ax.set_title("Original")

        # Right column: the model's reconstruction of that image.
        ax = plt.subplot(N_RECONSTRUCTIONS, 2, 2 * i + 2)
        ax.imshow(recon.view(28, 28).detach().cpu())
        ax.set_title("Reconstructed")
plt.show()

In [None]:
# Display the results loaded by MNISTTrainData.load_all_results (a dict keyed by loss name).
mnist_train_data.load_all_results()

In [None]:
def plot_metric(ress: dict, col: str, ax, run_index: int = 0) -> None:
    """Plot one metric column against epoch for every loss in `ress`.

    Args:
        ress: mapping from loss name to a list of runs, as returned by
            `MNISTTrainData.load_all_results`. Each run entry is indexed as
            `run[1]` to get a table with an 'epoch' column and a `col` column
            (presumably a `(run_no, DataFrame)` pair -- confirm against the loader).
        col: name of the metric column to put on the y-axis.
        ax: matplotlib Axes to draw into.
        run_index: which run of each loss to plot. The default, 0 (the first
            run), matches the previous behaviour. Negative indices count from the end.

    Losses that have no run at `run_index` are skipped instead of raising
    IndexError. For example, a loss whose results list is still empty is
    skipped. The legend lists only the losses actually plotted.
    """
    lines = []
    for loss, runs in ress.items():
        if not -len(runs) <= run_index < len(runs):
            continue  # no such run for this loss (e.g. nothing trained yet)
        table = runs[run_index][1]
        line, = ax.plot(table['epoch'], table[col], label=loss)
        lines.append(line)
    ax.legend(handles=lines)

In [None]:
import matplotlib.pyplot as plt  # kept so this cell also runs on its own

# Load the results once and reuse them for every panel, instead of calling
# load_all_results() separately for each metric.
all_results = mnist_train_data.load_all_results()

# (results column, y-axis label, log-scaled epoch axis?)
# NOTE(review): the -KL panel keeps a linear x-axis as before -- confirm
# whether a log scale was intended there too.
CONVERGENCE_PANELS = [
    ('iwae-64', 'IWAE-64', True),
    ('logpx', 'log(p(x))', True),
    ('-kl', '-KL', False),
]
EPOCH_XLIM = (24, 1000)  # epoch window shown on every panel

fig, axes = plt.subplots(len(CONVERGENCE_PANELS))
fig.set_figheight(12)
fig.set_figwidth(7)
for panel_ax, (col, ylabel, log_x) in zip(axes, CONVERGENCE_PANELS):
    plot_metric(all_results, col, panel_ax)
    if log_x:
        panel_ax.set_xscale('log')
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_xlim(*EPOCH_XLIM)
fig.suptitle('Convergence of evaluation metrics on the test set over time.')
fig.tight_layout()