In [3]:
%run Archpool.ipynb
%run Argparser.ipynb
%run Topo_treatment.ipynb
%run Net.ipynb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import model
from dataset import *
from utils import check_dir
from tqdm import tqdm
import numpy as np
from functools import reduce

import os
import json
import argparse

import sys
import json
sys.path.insert(0, './persis_lib_cpp')
from persis_homo_optimal import *

# Target device used throughout this file: the GPU when PyTorch can see one, else the CPU.
torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def parse_args(config_path='example.json'):
    """Load the training configuration from a JSON file.

    The command-line interface is currently bypassed; by default the file
    ``example.json`` in the working directory is read. The directory entries
    ``output_path``, ``train_data_path`` and ``val_data_path`` are normalized
    to end with ``/`` because the rest of this file builds paths by string
    concatenation.

    Args:
        config_path: path of the JSON configuration file.

    Returns:
        dict: the parsed configuration with normalized directory entries.

    Exits:
        Prints a message and calls ``sys.exit(1)`` when one of the three
        directory entries is missing.
    """
    with open(config_path, 'r') as inf:
        config = json.load(inf)

    for key in ('output_path', 'train_data_path', 'val_data_path'):
        try:
            path = config[key]
        except KeyError as err:
            print(f'{config_path}: Unspecified path {err}')
            sys.exit(1)
        # An empty path means "current directory"; don't turn it into '/'.
        if path and not path.endswith('/'):
            config[key] = path + '/'
    return config


def _build_optimizer(name, parameters, learning_rate):
    """Create the optimizer named *name* ('adam' or 'sgd') over *parameters*.

    Raises:
        ValueError: for any other name, so a config typo fails here rather
            than later inside the LR-scheduler constructor.
    """
    if name == 'adam':
        return optim.Adam(parameters, lr=learning_rate)
    if name == 'sgd':
        return optim.SGD(parameters, momentum=0.9, weight_decay=1e-2,
                         lr=learning_rate)
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'adam' or 'sgd'")


def initialize_network(config):
    """Build the training state (paths, model, loss, optimizer, scheduler).

    If ``config['resume']`` names an existing file, it is loaded as a
    checkpoint and the hyper-parameters/data paths stored in it are reused;
    otherwise they are read from *config*. Output, checkpoint and result
    directories are created with ``check_dir``.

    Side effects:
        Seeds torch's RNG with ``config.get('random_seed', 0)`` and sets
        ``config['resume'] = False`` when the key is absent.

    Args:
        config: dict produced by ``parse_args``.

    Returns:
        dict: the ``network`` state consumed by ``train``, ``validate`` and
        ``main``.

    Exits:
        Prints a message and calls ``sys.exit(1)`` when a required field is
        missing on a fresh (non-resumed) run.
    """
    network = {}
    # Read the two optional keys independently so a missing 'random_seed'
    # can never discard a requested resume (and vice versa).
    config.setdefault('resume', False)
    torch.manual_seed(config.get('random_seed', 0))

    if config['resume'] and os.path.isfile(config['resume']):
        confout = config['output_path'] + config['name'] + '/'
        checkpoint = torch.load(config['resume'], map_location=torch_device)
        network['resume'] = True
        # Continue the epoch count only when writing into the same output dir.
        if checkpoint['output_dir'] == confout:
            network['epoch_start'] = checkpoint['epoch'] + 1
        else:
            network['epoch_start'] = 0
        # A truthy 'epoch' in the config overrides the checkpoint's end epoch.
        network['epoch_end'] = config.get('epoch') or checkpoint['epoch_end']
        network['output_dir'] = confout
        for key in ('learning_rate', 'train_data_dir', 'val_data_dir', 'name',
                    'batch_size', 'features', 'image_size', 'image_channels',
                    'optimizer_name', 'arch', 'bn'):
            network[key] = checkpoint[key]
        network['checkpoint'] = checkpoint
    else:
        network['resume'] = False
        try:
            network['output_dir'] = config['output_path'] + \
                config['name'] + '/'
            network['name'] = config['name']
            network['epoch_start'] = 0
            network['epoch_end'] = config['epoch']
            network['learning_rate'] = config['learning_rate']
            network['batch_size'] = config['batch_size']
            network['features'] = int(config['features'])
            network['image_size'] = config['image_size']
            network['image_channels'] = config['image_channels']
            network['optimizer_name'] = config['optimizer']
            network['train_data_dir'] = config['train_data_path']
            network['val_data_dir'] = config['val_data_path']
            network['arch'] = config['arch']
            network['bn'] = config['bn']
        except KeyError as err:
            print(f'Configuration: Unspecified field {err}')
            sys.exit(1)

    network['checkpoint_dir'] = network['output_dir'] + 'checkpoints/'
    network['result_dir'] = network['output_dir'] + 'result/'
    for directory in (network['output_dir'], network['checkpoint_dir'],
                      network['result_dir']):
        check_dir(directory)

    network['logfile_path'] = network['result_dir'] + 'logfile.txt'
    network['performance_path'] = network['result_dir'] + 'performance.txt'

    learning_model = model.AutoEncoder(
        network['image_size'], network['image_channels'], network['features'],
        network['arch'], network['bn']).to(torch_device)

    # Equal class weights, i.e. plain (clamped) binary cross-entropy.
    network['loss_function'] = WeightedBCELoss(one_weight=1, zeros_weight=1)

    optimizer = _build_optimizer(network['optimizer_name'],
                                 learning_model.parameters(),
                                 network['learning_rate'])
    # Halve the LR once the value passed to scheduler.step() has failed to
    # decrease for more than one consecutive step.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=1)

    if network['resume']:
        learning_model.load_state_dict(network['checkpoint']['model'])
        optimizer.load_state_dict(network['checkpoint']['optimizer'])
        scheduler.load_state_dict(network['checkpoint']['scheduler'])

    network['model'] = learning_model
    network['optimizer'] = optimizer
    network['scheduler'] = scheduler
    return network


class WeightedBCELoss:
    """Binary cross-entropy with separate weights for positive and negative targets.

    Per element::

        loss = -(one_weight * y * log(x) + zeros_weight * (1 - y) * log(1 - x))

    Both the prediction ``x`` and the target ``y`` are clamped to
    ``[1e-7, 1 - 1e-7]`` before the logs are taken. This avoids inf/NaN, but
    predictions outside that range get zero gradient from this loss.

    Args:
        one_weight: weight of the ``y`` (positive-target) term.
        zeros_weight: weight of the ``1 - y`` (negative-target) term.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """

    def __init__(self, one_weight=1.0, zeros_weight=1.0, reduction="mean"):
        self.reduction = reduction
        self.update_weights(one_weight, zeros_weight)

    def update_weights(self, one_weight, zeros_weight):
        """Replace the class weights (stored as float32 ``[one_weight, zeros_weight]``)."""
        # Stored on the CPU; _bce moves them to the input's device per call
        # (Tensor.to is not in-place, so a bare `.to(...)` here had no effect).
        self.weights = torch.tensor([one_weight, zeros_weight],
                                    dtype=torch.float32)

    def _bce(self, x, y):
        weights = -self.weights.to(x.device)
        x = torch.clamp(x, min=1e-7, max=1 - 1e-7)
        y = torch.clamp(y, min=1e-7, max=1 - 1e-7)
        # weights[0] (one_weight) scales the positive term and weights[1]
        # (zeros_weight) the negative term, as the parameter names say.
        return weights[0] * y * torch.log(x) + weights[1] * (1 - y) * torch.log(1 - x)

    def __call__(self, pred, truth):
        loss = self._bce(pred, truth)
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss



def _minmax_scale(t, eps=1e-12):
    """Linearly rescale *t* to [0, 1] by its global min/max (eps avoids 0/0 on constant input)."""
    t_min = torch.min(t)
    span = torch.clamp(torch.max(t) - t_min, min=eps)
    return (t - t_min) / span


def train(adv_params, network, dataloader, withTopo=True):
    """Run one training epoch over *dataloader* and return averaged losses.

    For each batch the raw model output ``p`` is min-max rescaled to [-1, 1]
    (``prediction``) and scored with ``network['loss_function']``. When
    *withTopo* is true, ``Edges_.fix_with_topo`` (provided by one of the
    %run notebooks) builds a corrected target from the detached prediction.
    That target is rescaled to [0, 1] and compared with ``p`` via summed
    BCE-with-logits, scaled by ``et.return_tp_weight()``. The sum of the
    terms is back-propagated and the optimizer is stepped. A preview grid of
    every batch is written to ``{result_dir}_batch{n}.png``.

    Args:
        adv_params: forwarded to ``Edges_``.
        network: dict built by ``initialize_network``.
        dataloader: yields ``(scalars, label)`` batches.
        withTopo: add the topological loss term.

    Returns:
        ``([training_loss], [topo_loss], [total_loss])`` as floats:
        ``training_loss`` is the batch-size-weighted mean of the per-batch
        loss, ``topo_loss`` the summed topological loss divided by the dataset
        size, and ``total_loss`` their sum.
    """
    et = Edges_(adv_params, debug=False)
    criterionT = nn.BCEWithLogitsLoss(reduction="sum")
    loss_function = network['loss_function']
    net = network['model']
    optimizer = network['optimizer']
    result_dir = network['result_dir']
    h = w = network['image_size']
    net.train()

    running_loss = 0.0
    t_loss = 0.0
    batch_number = 0
    for step_num, (scalars, label) in enumerate(dataloader, start=1):
        print("step_num : ", step_num)
        scalars = scalars.to(torch_device)
        label = label.to(torch_device)
        optimizer.zero_grad()

        p = net(scalars)
        num_rows = p.size(0)
        # NOTE(review): if loss_function is WeightedBCELoss (as set in
        # initialize_network) it clamps inputs to (0, 1), so the negative half
        # of this [-1, 1] range is clipped; validate() scores the raw output
        # instead. Confirm the intended range.
        prediction = _minmax_scale(p) * 2 - 1

        loss = loss_function(prediction, label)
        running_loss += loss.item() * num_rows
        print("training loss : ", loss.item())

        # Preview grid; panels are interleaved per sample.
        # NOTE(review): the third panel repeats the label, whereas validate()
        # shows the thresholded prediction in that slot. Confirm intent.
        with torch.no_grad():
            panels = [x.detach().cpu().view(num_rows, 1, h, w).double()
                      for x in (scalars, label, label, prediction)]
            out_image = torch.stack(panels, dim=1).reshape(4 * num_rows, 1, h, w)
            save_image(out_image, f"{result_dir}_batch{batch_number}.png",
                       padding=4, nrow=24)
        batch_number += 1

        if withTopo:
            tp_wgt = et.return_tp_weight()
            fake_fix, _mean_wasdis = et.fix_with_topo(
                prediction.detach().cpu().numpy(), label.detach().cpu().numpy(),
                et.return_target_dim(), result_dir, batch_number, num_rows,
                -1.0, 1.0, blind=et.blind())
            # Match p's device and dtype so BCE-with-logits sees consistent inputs.
            fake_fix = torch.from_numpy(fake_fix).to(device=torch_device, dtype=p.dtype)
            fake_fix = _minmax_scale(torch.unsqueeze(fake_fix, 1))
            errT = criterionT(p, fake_fix) * tp_wgt
            t_loss += errT.item()
            print("Topo loss : ", errT.item())
            tot_loss = loss + errT
            print("Total loss : ", tot_loss.item())
        else:
            tot_loss = loss

        tot_loss.backward()
        optimizer.step()

    n_samples = max(len(dataloader.dataset), 1)
    training_loss = running_loss / n_samples
    topo_loss = t_loss / n_samples
    total_loss = training_loss + topo_loss
    return [training_loss], [topo_loss], [total_loss]


def validate(network, dataloader, epoch):
    """Evaluate ``network['model']`` on *dataloader* with gradients disabled.

    Each batch is scored with ``network['loss_function']`` on the raw model
    output. Pixels are labelled positive where that output is >= 0.3. A
    preview grid (input, label, thresholded prediction, raw prediction,
    interleaved per sample) is saved to
    ``{result_dir}epoch_{epoch}_batch{n}.png``.

    Args:
        network: dict built by ``initialize_network``.
        dataloader: yields ``(scalars, label)`` batches.
        epoch: used only in the preview file names.

    Returns:
        ``([val_loss], [accuracy, precision, recall, f1, l1_diff])`` as Python
        floats. ``val_loss`` is the sample-weighted mean loss, and ``l1_diff``
        is the summed absolute error per sample. A ratio whose denominator is
        zero (e.g. precision with no positive predictions) is NaN.

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        raise ValueError('validate() needs a non-empty dataset')

    h = w = network['image_size']
    loss_function = network['loss_function']
    net = network['model']
    result_dir = network['result_dir']
    net.eval()

    running_loss = 0.0
    tp = tn = fp = fn = 0
    l1_diff = 0.0
    with torch.no_grad():
        for batch_number, (scalars, label) in enumerate(dataloader):
            label = label.to(torch_device)
            scalars = scalars.to(torch_device)
            batch_size = label.size(0)

            prediction = net(scalars)
            # NOTE(review): train() rescales the output to [-1, 1] before the
            # loss; here the raw output is used. Confirm which is intended.
            loss = loss_function(prediction, label)
            running_loss += loss.item() * batch_size

            pred = prediction.cpu().view(batch_size, -1).double()
            truth = label.cpu().view(batch_size, -1).double()
            pred_pos = pred >= 0.3
            true_pos = truth == 1
            true_neg = truth == 0
            # Integer counts stay exact (float32 sums lose precision past 2**24).
            tp += int(torch.sum(pred_pos & true_pos))
            tn += int(torch.sum(~pred_pos & true_neg))
            fp += int(torch.sum(pred_pos & true_neg))
            fn += int(torch.sum(~pred_pos & true_pos))
            l1_diff += torch.sum(torch.abs(pred - truth)).item()

            plabel = pred_pos.double()
            panels = [x.cpu().view(batch_size, 1, h, w).double()
                      for x in (scalars, label, plabel, prediction)]
            out_image = torch.stack(panels, dim=1).reshape(4 * batch_size, 1, h, w)
            save_image(out_image,
                       f"{result_dir}epoch_{epoch}_batch{batch_number}.png",
                       padding=4, nrow=24)

    def _ratio(num, den):
        # 0/0 -> NaN (what tensor division produced) rather than raising.
        return num / den if den else float('nan')

    val_loss = running_loss / n_samples
    accuracy = _ratio(tp + tn, tp + tn + fp + fn)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * tp, 2 * tp + fp + fn)
    return [val_loss], [accuracy, precision, recall, f1, l1_diff / n_samples]


def floats2str(l):
    """Render every number in *l* with six decimals, joined by commas (no spaces)."""
    pieces = []
    for value in l:
        pieces.append(format(value, '.6f'))
    return ','.join(pieces)


def parameters_count(model):
    """Count the parameter elements of *model*.

    Returns:
        tuple: ``(total, trainable)``, where ``trainable`` counts only the
        parameters with ``requires_grad`` set.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return total, trainable


def _save_checkpoint(network, epoch):
    """Write the model, optimizer and scheduler state, plus the fields initialize_network restores, to ``{checkpoint_dir}{name}_{epoch}.pth``."""
    torch.save({
        'epoch': epoch,
        'epoch_end': network['epoch_end'],
        'model': network['model'].state_dict(),
        'optimizer': network['optimizer'].state_dict(),
        'optimizer_name': network['optimizer_name'],
        'scheduler': network['scheduler'].state_dict(),
        'checkpoint_dir': network['checkpoint_dir'],
        'train_data_dir': network['train_data_dir'],
        'val_data_dir': network['val_data_dir'],
        'output_dir': network['output_dir'],
        'name': network['name'],
        'batch_size': network['batch_size'],
        'learning_rate': network['learning_rate'],
        'features': network['features'],
        'image_size': network['image_size'],
        'image_channels': network['image_channels'],
        'arch': network['arch'],
        'bn': network['bn']
    }, f'{network["checkpoint_dir"]}{network["name"]}_{epoch}.pth')


def main():
    """Train for epochs ``[epoch_start, epoch_end)``.

    After every epoch this runs validation, steps the LR scheduler, and
    appends one CSV row to the loss log and to the performance log. A
    checkpoint is saved every ``config['checkpoint_every']`` epochs
    (default 50) and after the final epoch. Log files are closed even if
    training is interrupted.
    """
    adv_params = return_advanced_params()
    config = parse_args()
    network = initialize_network(config)

    p, pt = parameters_count(network['model'])
    print(f'number of parameters(trainable) {p}({pt})')

    with open(network['output_dir'] + 'config.json', 'w') as jsonout:
        json.dump(config, jsonout, indent=2)

    # Use the paths recorded in `network` (equal to the config's on a fresh
    # run, the checkpoint's on resume) so saved checkpoints describe the data
    # actually used.
    train_dataset = ImageBoundary(
        network['train_data_dir'], network['image_channels'])
    train_dataloader = DataLoader(
        train_dataset, batch_size=network['batch_size'], shuffle=True)
    val_dataset = ImageBoundary(
        network['val_data_dir'], network['image_channels'])
    val_dataloader = DataLoader(
        val_dataset, batch_size=network['batch_size'], shuffle=False)

    checkpoint_every = max(1, int(config.get('checkpoint_every', 50)))
    mode = 'a' if network['resume'] else 'w'
    with open(network['logfile_path'], mode) as logfile, \
            open(network['performance_path'], mode) as perf_log:
        if not network['resume']:
            logfile.write('epoch,train_loss,val_loss,topo_loss,total_loss\n')
            perf_log.write(
                'epoch,accuracy,precision,recall,f1,l1_diff_per_image\n')

        for epoch in tqdm(range(network['epoch_start'], network['epoch_end'])):
            t_loss, topo_loss, total_loss = train(
                adv_params, network, train_dataloader, withTopo=True)
            v_loss, performance = validate(network, val_dataloader, epoch)
            # NOTE(review): the scheduler monitors the training total loss,
            # not v_loss. Confirm this is intended.
            network['scheduler'].step(total_loss[0])

            perf_log.write(f'{epoch},{floats2str(performance)}\n')
            perf_log.flush()

            logfile.write(
                f'{epoch},{floats2str(t_loss)},{floats2str(v_loss)},'
                f'{floats2str(topo_loss)},{floats2str(total_loss)}\n')
            logfile.flush()

            if (epoch + 1) % checkpoint_every == 0 or epoch == network['epoch_end'] - 1:
                _save_checkpoint(network, epoch)


# Entry point. Inside a notebook cell __name__ is '__main__', so executing the
# cell starts training immediately.
if __name__ == '__main__':
    main()

number of parameters(trainable) 7762465(7762465)




step_num :  1
training loss :  0.12208235263824463
Computing 1D 1-wasserstein distance.
Topo loss :  6.768273830413818
Total loss :  6.890356063842773
step_num :  2
training loss :  0.13557016849517822
Computing 1D 1-wasserstein distance.
Topo loss :  6.744208335876465
Total loss :  6.8797783851623535
step_num :  3
training loss :  0.09971626102924347
Computing 1D 1-wasserstein distance.
Topo loss :  6.773454189300537
Total loss :  6.873170375823975
step_num :  4
training loss :  0.09307428449392319
Computing 1D 1-wasserstein distance.
Topo loss :  6.710087299346924
Total loss :  6.80316162109375
step_num :  5
training loss :  0.11126159876585007
Computing 1D 1-wasserstein distance.
Topo loss :  6.733229637145996
Total loss :  6.844491004943848
step_num :  6
training loss :  0.13962586224079132
Computing 1D 1-wasserstein distance.
Topo loss :  6.75391149520874
Total loss :  6.893537521362305
step_num :  7
training loss :  0.1460723876953125
Computing 1D 1-wasserstein distance.
Topo los

Computing 1D 1-wasserstein distance.
Topo loss :  6.911436557769775
Total loss :  8.052356719970703
step_num :  56
training loss :  0.9674803018569946
Computing 1D 1-wasserstein distance.
Topo loss :  6.898138046264648
Total loss :  7.8656182289123535
step_num :  57
training loss :  0.9603796005249023
Computing 1D 1-wasserstein distance.
Topo loss :  6.88740348815918
Total loss :  7.847783088684082
step_num :  58
training loss :  0.9631537795066833
Computing 1D 1-wasserstein distance.
Topo loss :  6.8859171867370605
Total loss :  7.849071025848389
step_num :  59
training loss :  1.0173977613449097
Computing 1D 1-wasserstein distance.
Topo loss :  6.947256088256836
Total loss :  7.964653968811035
step_num :  60
training loss :  1.136591911315918
Computing 1D 1-wasserstein distance.
Topo loss :  6.9290852546691895
Total loss :  8.065677642822266
step_num :  61
training loss :  1.1582225561141968
Computing 1D 1-wasserstein distance.
Topo loss :  6.927536487579346
Total loss :  8.085759162

Computing 1D 1-wasserstein distance.
Topo loss :  6.868361473083496
Total loss :  7.786523818969727
step_num :  111
training loss :  0.8888723254203796
Computing 1D 1-wasserstein distance.
Topo loss :  6.864999294281006
Total loss :  7.753871440887451
step_num :  112
training loss :  1.0058870315551758
Computing 1D 1-wasserstein distance.
Topo loss :  6.842706203460693
Total loss :  7.848593235015869
step_num :  113
training loss :  1.0015558004379272
Computing 1D 1-wasserstein distance.
Topo loss :  6.86182165145874
Total loss :  7.863377571105957
step_num :  114
training loss :  0.9421771168708801
Computing 1D 1-wasserstein distance.
Topo loss :  6.839972972869873
Total loss :  7.7821502685546875
step_num :  115
training loss :  1.1471514701843262
Computing 1D 1-wasserstein distance.
Topo loss :  6.848733901977539
Total loss :  7.995885372161865
step_num :  116
training loss :  0.9339360594749451
Computing 1D 1-wasserstein distance.
Topo loss :  6.856748580932617
Total loss :  7.7906

Computing 1D 1-wasserstein distance.
Topo loss :  6.9216742515563965
Total loss :  7.8618035316467285
step_num :  165
training loss :  1.05824613571167
Computing 1D 1-wasserstein distance.
Topo loss :  6.908208847045898
Total loss :  7.966454982757568
step_num :  166
training loss :  1.0852092504501343
Computing 1D 1-wasserstein distance.
Topo loss :  6.90596342086792
Total loss :  7.991172790527344
step_num :  167
training loss :  1.0822886228561401
Computing 1D 1-wasserstein distance.
Topo loss :  6.88303804397583
Total loss :  7.96532678604126
step_num :  168
training loss :  0.90835040807724
Computing 1D 1-wasserstein distance.
Topo loss :  6.896989822387695
Total loss :  7.80534029006958
step_num :  169
training loss :  1.0747288465499878
Computing 1D 1-wasserstein distance.
Topo loss :  6.909494876861572
Total loss :  7.98422384262085
step_num :  170
training loss :  1.115190863609314
Computing 1D 1-wasserstein distance.
Topo loss :  6.902739524841309
Total loss :  8.017930030822

Computing 1D 1-wasserstein distance.
Topo loss :  6.8712873458862305
Total loss :  7.739677906036377
step_num :  219
training loss :  0.9358305335044861
Computing 1D 1-wasserstein distance.
Topo loss :  6.90731954574585
Total loss :  7.8431501388549805
step_num :  220
training loss :  0.9782935380935669
Computing 1D 1-wasserstein distance.
Topo loss :  6.872949600219727
Total loss :  7.851243019104004
step_num :  221
training loss :  0.9773964285850525
Computing 1D 1-wasserstein distance.
Topo loss :  6.872952461242676
Total loss :  7.850348949432373
step_num :  222
training loss :  1.0325387716293335
Computing 1D 1-wasserstein distance.
Topo loss :  6.87471866607666
Total loss :  7.907257556915283
step_num :  223
training loss :  0.9924383163452148
Computing 1D 1-wasserstein distance.
Topo loss :  6.919709205627441
Total loss :  7.912147521972656
step_num :  224
training loss :  0.7465094923973083
Computing 1D 1-wasserstein distance.
Topo loss :  6.9251179695129395
Total loss :  7.671

Computing 1D 1-wasserstein distance.
Topo loss :  6.8360443115234375
Total loss :  7.807507038116455
step_num :  273
training loss :  0.9427432417869568
Computing 1D 1-wasserstein distance.
Topo loss :  6.822476387023926
Total loss :  7.765219688415527
step_num :  274
training loss :  0.9821323752403259
Computing 1D 1-wasserstein distance.
Topo loss :  6.827972888946533
Total loss :  7.810105323791504
step_num :  275
training loss :  0.9271522760391235
Computing 1D 1-wasserstein distance.
Topo loss :  6.838403224945068
Total loss :  7.765555381774902
step_num :  276
training loss :  0.9099233746528625
Computing 1D 1-wasserstein distance.
Topo loss :  6.8228325843811035
Total loss :  7.7327561378479
step_num :  277
training loss :  1.0128129720687866
Computing 1D 1-wasserstein distance.
Topo loss :  6.792267799377441
Total loss :  7.805080890655518
step_num :  278
training loss :  1.046094298362732
Computing 1D 1-wasserstein distance.
Topo loss :  6.82405424118042
Total loss :  7.870148

Computing 1D 1-wasserstein distance.
Topo loss :  6.852574348449707
Total loss :  7.819463729858398
step_num :  327
training loss :  0.9588813781738281
Computing 1D 1-wasserstein distance.
Topo loss :  6.864874362945557
Total loss :  7.823755741119385
step_num :  328
training loss :  0.8755060434341431
Computing 1D 1-wasserstein distance.
Topo loss :  6.889657974243164
Total loss :  7.765163898468018
step_num :  329
training loss :  0.9251636862754822
Computing 1D 1-wasserstein distance.
Topo loss :  6.8824076652526855
Total loss :  7.8075714111328125
step_num :  330
training loss :  0.8991197943687439
Computing 1D 1-wasserstein distance.
Topo loss :  6.862534523010254
Total loss :  7.761654376983643
step_num :  331
training loss :  0.984118640422821
Computing 1D 1-wasserstein distance.
Topo loss :  6.860263824462891
Total loss :  7.844382286071777
step_num :  332
training loss :  0.9465054869651794
Computing 1D 1-wasserstein distance.
Topo loss :  6.876963138580322
Total loss :  7.823

Computing 1D 1-wasserstein distance.
Topo loss :  6.623136043548584
Total loss :  7.699174880981445
step_num :  381
training loss :  1.1398342847824097
Computing 1D 1-wasserstein distance.
Topo loss :  6.637139320373535
Total loss :  7.776973724365234
step_num :  382
training loss :  1.0602091550827026
Computing 1D 1-wasserstein distance.
Topo loss :  6.662068843841553
Total loss :  7.722278118133545
step_num :  383
training loss :  1.1022194623947144
Computing 1D 1-wasserstein distance.
Topo loss :  6.7049360275268555
Total loss :  7.807155609130859
step_num :  384
training loss :  1.1836950778961182
Computing 1D 1-wasserstein distance.
Topo loss :  6.683860778808594
Total loss :  7.867555618286133
step_num :  385
training loss :  1.158677101135254
Computing 1D 1-wasserstein distance.
Topo loss :  6.663756370544434
Total loss :  7.8224334716796875
step_num :  386
training loss :  1.1897035837173462
Computing 1D 1-wasserstein distance.
Topo loss :  6.665561676025391
Total loss :  7.855

  3%|██▋                                                                             | 1/30 [04:06<1:59:18, 246.83s/it]

step_num :  1
training loss :  0.8265628218650818
Computing 1D 1-wasserstein distance.
Topo loss :  6.794521331787109
Total loss :  7.621084213256836
step_num :  2
training loss :  0.8053750395774841
Computing 1D 1-wasserstein distance.
Topo loss :  6.766937255859375
Total loss :  7.572312355041504
step_num :  3
training loss :  1.096373438835144
Computing 1D 1-wasserstein distance.
Topo loss :  6.758614540100098
Total loss :  7.854988098144531
step_num :  4
training loss :  0.9546833038330078
Computing 1D 1-wasserstein distance.
Topo loss :  6.77022647857666
Total loss :  7.724909782409668
step_num :  5
training loss :  0.915448009967804
Computing 1D 1-wasserstein distance.
Topo loss :  6.794043064117432
Total loss :  7.70949125289917
step_num :  6
training loss :  1.0897878408432007
Computing 1D 1-wasserstein distance.
Topo loss :  6.721987247467041
Total loss :  7.811775207519531
step_num :  7
training loss :  1.007663369178772
Computing 1D 1-wasserstein distance.
Topo loss :  6.775

  3%|██▋                                                                             | 1/30 [04:15<2:03:39, 255.85s/it]


KeyboardInterrupt: 