In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np
import random

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


# from SageMix import SageMix
from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import wandb
import torch.nn.functional as F
# import io

import torch
from emd_ import emd_module

In [25]:
class SageMix:
    """Mix several point clouds of a batch into one (SageMix-style augmentation).

    Every sample is blended with ``n_mix - 1`` partner clouds chosen by random
    permutations of the same batch. Partners are first re-ordered point-by-point
    to match the original cloud using the EMD module's assignment. Each output
    point is then a convex combination of the corresponding aligned points. The
    weights come from Gaussian kernels around one anchor point per cloud, scaled
    by Dirichlet-sampled proportions. Labels are mixed with each cloud's share of
    the total weight. Anchors are drawn uniformly by default, or from saliency
    when ``saliency_anchors=True``.
    """

    def __init__(self, args, num_class=40):
        self.num_class = num_class
        # Assignment solver; mix() calls it as EMD(a, b, eps, iters) -> (_, assignment).
        self.EMD = emd_module.emdModule()
        # Bandwidth of the Gaussian kernel placed around each anchor point.
        self.sigma = args.sigma
        # NOTE(review): not used by mix(), which samples Dirichlet proportions from its own ``theta``.
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))

    @staticmethod
    def _anchor_probs(sal, points, prev_anchors):
        """Anchor sampling weights (B, N): saliency times summed distance to earlier anchors."""
        probs = sal
        if prev_anchors:
            dist = sum(((points - anchor[:, None, :]) ** 2).sum(-1).sqrt() for anchor in prev_anchors)
            probs = probs * dist
        # Keep every row strictly positive so torch.multinomial never sees an all-zero row.
        probs = probs + 1e-12
        return probs / probs.sum(-1, keepdim=True)

    def mix(self, xyz, label, saliency=None, n_mix=4, theta=0.2,
            saliency_anchors=False, emd_eps=0.005, emd_iters=500):
        """Blend each cloud in the batch with ``n_mix - 1`` permuted partners.

        Args:
            xyz (B, N, 3): input point clouds.
            label (B,): integer class labels.
            saliency (B, N): per-point saliency. Only read when ``saliency_anchors``
                is True; may be None otherwise.
            n_mix: number of clouds blended into each output sample (1 returns the input).
            theta: concentration of the symmetric Dirichlet over mixing proportions.
            saliency_anchors: sample anchors from saliency (weighted by distance to
                earlier anchors) instead of uniformly over the N points.
            emd_eps, emd_iters: forwarded to the EMD module as its 3rd/4th arguments.

        Returns:
            x (B, N, 3): mixed clouds (float32, on ``xyz.device``).
            label (B, num_class): soft labels whose rows sum to 1.

        Raises:
            ValueError: if ``saliency_anchors`` is True but ``saliency`` is None.
        """
        B, N, _ = xyz.shape
        device = xyz.device
        # All mixing buffers were float32 in the original implementation.
        xyz = xyz.float()
        if saliency_anchors and saliency is None:
            raise ValueError("saliency is required when saliency_anchors=True")

        # idxs[i] selects each sample's partner for cloud i. Row 0 is unused because
        # cloud 0 is always the unpermuted batch.
        idxs = torch.stack([torch.randperm(B) for _ in range(n_mix)]).to(device)

        # Align every partner to xyz: point k of aligned_xyz[i] is EMD-matched to point k of xyz.
        aligned_xyz = [xyz]
        aligned_sal = [saliency] if saliency_anchors else None
        for i in range(1, n_mix):
            partner = xyz[idxs[i]]
            _, assignment = self.EMD(xyz, partner, emd_eps, emd_iters)
            assignment = assignment.long()
            aligned_xyz.append(torch.gather(partner, 1, assignment.unsqueeze(-1).expand(-1, -1, 3)))
            if saliency_anchors:
                aligned_sal.append(torch.gather(saliency[idxs[i]], 1, assignment))

        # One anchor point per cloud; indices are drawn over the actual N points.
        batch_idx = torch.arange(B, device=device)
        anchors = []
        for i in range(n_mix):
            if saliency_anchors:
                probs = self._anchor_probs(aligned_sal[i], aligned_xyz[i], anchors)
                anc_idx = torch.multinomial(probs, 1, replacement=True)
            else:
                anc_idx = torch.randint(0, N, (B, 1)).to(device)
            anchors.append(aligned_xyz[i][batch_idx, anc_idx[:, 0]])

        # Global mixing proportions per sample: (B, n_mix).
        pi = torch.distributions.dirichlet.Dirichlet(
            torch.full((n_mix,), float(theta))).sample((B,)).to(device)

        # Per-point weight of cloud i: Gaussian kernel around its anchor, scaled by pi_i.
        kernel_w = [
            torch.exp(-0.5 * ((pts - anc[:, None, :]) ** 2).sum(-1) / (self.sigma ** 2)) * pi[:, i, None]
            for i, (pts, anc) in enumerate(zip(aligned_xyz, anchors))
        ]
        # Epsilon keeps the normalisation finite where every kernel underflows to 0.
        weight = torch.stack(kernel_w, dim=-1) + 1e-16          # (B, N, n_mix)
        weight = weight / weight.sum(-1, keepdim=True)

        x = (weight.unsqueeze(-1) * torch.stack(aligned_xyz, dim=2)).sum(2)   # (B, N, 3)

        # Label share of each cloud = its fraction of the total per-point weight.
        target = weight.sum(1)
        target = target / target.sum(-1, keepdim=True)          # (B, n_mix)

        onehot = torch.zeros(B, self.num_class, device=device).scatter(1, label.reshape(-1, 1).long(), 1)
        mixed_label = onehot * target[:, :1]
        for i in range(1, n_mix):
            mixed_label = mixed_label + onehot[idxs[i]] * target[:, i:i + 1]
        return x, mixed_label

In [26]:
def distance(z, dist_type='l2'):
    """Pairwise distance matrix between the entries along the first axis of ``z``.

    Args:
        z: tensor; the differences are reduced over its last dimension.
        dist_type: 'l2' (Euclidean), 'l22' (squared Euclidean), 'l1' or 'linf'.
            Any other string starting with 'l2' also yields squared distances.

    Returns:
        Tensor ``D`` with ``D[i, j]`` the distance between ``z[i]`` and ``z[j]``,
        computed without autograd tracking, or ``None`` for an unrecognised type.
    """
    with torch.no_grad():
        delta = z[:, None] - z[None, :]
        if dist_type[:2] == 'l2':
            squared = delta.pow(2).sum(dim=-1)
            return squared.sqrt() if dist_type == 'l2' else squared
        if dist_type == 'l1':
            return delta.abs().sum(dim=-1)
        if dist_type == 'linf':
            return delta.abs().max(dim=-1).values
    return None


def calc_A_dist(saliency, theta=0.5):
    """Blend the identity with a normalised distance matrix between saliency peaks.

    For each sample, the flat index of its maximum saliency is located. Pairwise L1
    distances between those positions (via ``distance``) are rescaled to sum to
    ``n_input`` and mixed with the identity using ``m_omega ~ Beta(theta, theta)``:
    ``A = (1 - m_omega) * I + m_omega * A_dist``.

    Args:
        saliency: tensor whose first dimension is the batch; presumably (B, N),
            like the per-point saliency computed in ``train`` (TODO confirm).
        theta: Beta concentration for the mixing weight ``m_omega``.

    Returns:
        (B, B) tensor on ``saliency.device``.
    """
    # Batch size and device come from the input itself, not from globals
    # (the original used ``args.batch_size`` and an undefined ``out``).
    n_input = saliency.shape[0]
    device = saliency.device

    z = saliency.unsqueeze(1)
    peak_flat = torch.argmax(z.reshape(n_input, -1), dim=1)
    # (row, col) of the peak when the flattened values are read as rows of width z.shape[-1].
    peak_pos = torch.zeros((n_input, 2), device=device)
    peak_pos[:, 0] = peak_flat // z.shape[-1]
    peak_pos[:, 1] = peak_flat % z.shape[-1]
    A_dist = distance(peak_pos, dist_type='l1')

    A_base = torch.eye(n_input, device=device)
    total = A_dist.sum()
    # When all peaks coincide A_dist is all zeros; skip the rescale instead of producing NaNs.
    if total > 0:
        A_dist = A_dist / total * n_input
    m_omega = torch.distributions.beta.Beta(theta, theta).sample()
    return (1 - m_omega) * A_base + m_omega * A_dist


In [41]:
def _init_():
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    if not os.path.exists('checkpoints/'+args.exp_name):
        os.makedirs('checkpoints/'+args.exp_name)
    if not os.path.exists('checkpoints/'+args.exp_name+'/'+'models'):
        os.makedirs('checkpoints/'+args.exp_name+'/'+'models')
    os.system('cp main.py checkpoints'+'/'+args.exp_name+'/'+'main.py.backup')
    os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
    os.system('cp util.py checkpoints' + '/' + args.exp_name + '/' + 'util.py.backup')
    os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup')

def _build_loaders(args):
    """Create the train/test DataLoaders for ``args.data``; returns (train, test, num_class)."""
    if args.data == 'MN40':
        train_set = ModelNet40(partition='train', num_points=args.num_points)
        test_set = ModelNet40(partition='test', num_points=args.num_points)
        num_class = 40
        # NOTE(review): hard-coded override of --batch_size kept from the original
        # experiment (it also mutates ``args``); remove it to honour the CLI value.
        args.batch_size = 24
        print('args.batch_size:', args.batch_size)
    elif args.data in ('SONN_easy', 'SONN_hard'):
        ver = 'easy' if args.data == 'SONN_easy' else 'hard'
        train_set = ScanObjectNN(partition='train', num_points=args.num_points, ver=ver)
        test_set = ScanObjectNN(partition='test', num_points=args.num_points, ver=ver)
        num_class = 15
    else:
        raise ValueError('Unknown dataset: {}'.format(args.data))

    train_loader = DataLoader(train_set, num_workers=8, batch_size=args.batch_size,
                              shuffle=True, drop_last=True)
    test_loader = DataLoader(test_set, num_workers=8, batch_size=args.test_batch_size,
                             shuffle=True, drop_last=False)
    return train_loader, test_loader, num_class


def _input_saliency(model, data, label, opt):
    """Per-point saliency: RMS over the xyz channels of d(loss)/d(input), computed in eval mode."""
    model.eval()
    data_var = data.permute(0, 2, 1).detach().requires_grad_(True)
    loss = cal_loss(model(data_var), label, smoothing=False)
    loss.backward()
    # backward() also filled the parameter grads; clear them before the real training step.
    opt.zero_grad()
    return torch.sqrt(torch.mean(data_var.grad ** 2, 1))


def _evaluate(model, loader, device):
    """One no-grad pass over ``loader``; returns (mean loss, accuracy, balanced accuracy)."""
    model.eval()
    total_loss, seen = 0.0, 0
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for data, label in tqdm(loader):
            # reshape(-1) instead of squeeze() keeps a 1-D label for a final batch of size 1.
            data, label = data.to(device), label.to(device).reshape(-1)
            logits = model(data.permute(0, 2, 1))
            loss = cal_loss(logits, label)
            batch = data.size(0)
            seen += batch
            total_loss += loss.item() * batch
            true_chunks.append(label.cpu().numpy())
            pred_chunks.append(logits.max(dim=1)[1].cpu().numpy())
    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    return (total_loss / seen,
            metrics.accuracy_score(y_true, y_pred),
            metrics.balanced_accuracy_score(y_true, y_pred))


def train(args, io):
    """Train ``args.model`` with SageMix augmentation and evaluate after every epoch.

    The best model by test accuracy is saved to ``checkpoints/<exp_name>/models/model.t7``.
    During epoch 30 the first ``args.save_count`` mixed batches are dumped to
    ``Data_for_viz/`` as ``{data,label}_{unmixed,mixed}_<k>.npy`` and ``saliency_<k>.npy``.

    Args:
        args: parsed options (data, model, lr, momentum, use_sgd, epochs, cuda,
            fixed_mixup, save_count, exp_name, ...).
        io: logger exposing ``cprint``.
    """
    train_loader, test_loader, num_class = _build_loaders(args)
    device = torch.device("cuda" if args.cuda else "cpu")

    if args.model == 'pointnet':
        model = PointNet(args, num_class).to(device)
    elif args.model == 'dgcnn':
        model = DGCNN(args, num_class).to(device)
    else:
        raise Exception("Not implemented")
    print(str(model))

    model = nn.DataParallel(model)
    print("Let's use", torch.cuda.device_count(), "GPUs!")

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

    sagemix = SageMix(args, num_class)
    criterion = cal_loss_mix

    viz_dir = "Data_for_viz"
    viz_epoch = 30
    best_test_acc = 0
    cnt = int(args.save_count)
    print("count:", cnt)

    for epoch in range(args.epochs):
        # ---------------- Train ----------------
        train_loss = 0.0
        count = 0.0
        model.train()
        for data, label in tqdm(train_loader):
            data, label = data.to(device), label.to(device).reshape(-1)
            batch_size = data.size(0)

            n_mix = random.randint(2, 4) if args.fixed_mixup is None else int(args.fixed_mixup)
            if n_mix > 1:
                saliency = _input_saliency(model, data, label, opt)

                # Dump the first `save_count` mixed batches of `viz_epoch` for visualisation.
                save_viz = False
                if epoch == viz_epoch:
                    cnt -= 1
                    save_viz = cnt >= 0
                if save_viz:
                    print("count:", cnt)
                    print("saving data in epoch", epoch, "for batch count", cnt)
                    print(data.shape)
                    os.makedirs(viz_dir, exist_ok=True)
                    # Indexed by `cnt` (not the running sample count) so files pair with the mixed ones.
                    np.save("{}/data_unmixed_{}.npy".format(viz_dir, cnt), data.cpu().numpy())
                    np.save("{}/label_unmixed_{}.npy".format(viz_dir, cnt), label.cpu().numpy())

                data, label = sagemix.mix(data, label, saliency, n_mix)

                if save_viz:
                    np.save("{}/data_mixed_{}.npy".format(viz_dir, cnt), data.cpu().numpy())
                    np.save("{}/label_mixed_{}.npy".format(viz_dir, cnt), label.cpu().numpy())
                    np.save("{}/saliency_{}.npy".format(viz_dir, cnt), saliency.cpu().numpy())

            model.train()
            opt.zero_grad()
            logits = model(data.permute(0, 2, 1))
            # Mixed batches carry soft labels; unmixed ones keep integer labels.
            loss = criterion(logits, label) if n_mix > 1 else cal_loss(logits, label)
            loss.backward()
            opt.step()
            count += batch_size
            train_loss += loss.item() * batch_size

        scheduler.step()
        io.cprint('Train %d, loss: %.6f' % (epoch, train_loss*1.0/count))

        # ---------------- Test ----------------
        test_loss, test_acc, avg_per_class_acc = _evaluate(model, test_loader, device)
        if test_acc >= best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, best test acc: %.6f' % (
            epoch, test_loss, test_acc, avg_per_class_acc, best_test_acc)
        io.cprint(outstr)


In [43]:
# Training settings
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is a trap: ``bool("False")`` is True, so any explicit value
    would switch the flag on. This accepts the usual true/false spellings.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Point Cloud Recognition')
parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
                    help='Name of the experiment')
parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
                    choices=['pointnet', 'dgcnn'],
                    help='Model to use, [pointnet, dgcnn]')
parser.add_argument('--data', type=str, default='MN40', metavar='N',
                    choices=['MN40', 'SONN_easy', 'SONN_hard'])  # SONN_easy : OBJ_ONLY, SONN_hard : PB_T50_RS
parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                    help='Size of batch')
parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                    help='Size of test batch')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--use_sgd', type=_str2bool, default=True,
                    help='Use SGD')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001, 0.1 if using sgd)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no_cuda', type=_str2bool, default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--eval', type=_str2bool, default=False,
                    help='evaluate the model')
parser.add_argument('--num_points', type=int, default=1024,
                    help='num of points to use')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout rate')
parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                    help='Dimension of embeddings')
parser.add_argument('--k', type=int, default=20, metavar='N',
                    help='Num of nearest neighbors to use')
parser.add_argument('--m_omega', type=float, default=0.9,
                    help='omega parameter')
parser.add_argument('--mapping', type=str, default='emd',
                    help='mapping function')
parser.add_argument('--model_path', type=str, default='', metavar='N',
                    help='Pretrained model path')
parser.add_argument('--fixed_mixup', type=str, default=None, metavar='N',
                    help='number of mixes')
parser.add_argument('--save_count', type=str, default="5", metavar='N')
parser.add_argument('--sigma', type=float, default=-1)
parser.add_argument('--theta', type=float, default=0.2)
args = parser.parse_args([])

print(args)
# -1 means "use the model-specific default kernel bandwidth".
if args.sigma == -1:
    if args.model == 'dgcnn':
        args.sigma = 0.3
    elif args.model == "pointnet":
        args.sigma = 2.0

_init_()

# The optimizer is chosen per model here, overriding --use_sgd.
if args.model == 'dgcnn':
    args.use_sgd = True
elif args.model == "pointnet":
    args.use_sgd = False

io = IOStream('checkpoints/' + args.exp_name + '/run.log')
io.cprint(str(args))

args.cuda = not args.no_cuda and torch.cuda.is_available()
# Seed every RNG training draws from: python `random` picks n_mix per batch.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    io.cprint(
        'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
    torch.cuda.manual_seed(args.seed)
else:
    io.cprint('Using CPU')

# Experiment override: always blend 3 clouds per sample.
args.fixed_mixup = "3"
train(args, io)

# if not args.eval:
#     train(args, io)
# else:
#     test(args, io)

Namespace(batch_size=32, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='exp', fixed_mixup=None, k=20, lr=0.001, m_omega=0.9, mapping='emd', model='dgcnn', model_path='', momentum=0.9, no_cuda=False, num_points=1024, save_count='5', seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)
Namespace(batch_size=32, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='exp', fixed_mixup=None, k=20, lr=0.001, m_omega=0.9, mapping='emd', model='dgcnn', model_path='', momentum=0.9, no_cuda=False, num_points=1024, save_count='5', seed=1, sigma=0.3, test_batch_size=16, theta=0.2, use_sgd=True)
Using GPU : 0 from 6 devices
args.batch_size: 24
DGCNN(
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(256, eps=1

100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 0, loss: 3.492956


100%|██████████| 155/155 [00:03<00:00, 42.10it/s]


Test 0, loss: 3.181736, test acc: 0.220827, test avg acc: 0.136250, best test acc: 0.220827


100%|██████████| 410/410 [01:48<00:00,  3.78it/s]


Train 1, loss: 3.104944


100%|██████████| 155/155 [00:03<00:00, 41.61it/s]


Test 1, loss: 2.849949, test acc: 0.350486, test avg acc: 0.232227, best test acc: 0.350486


100%|██████████| 410/410 [01:48<00:00,  3.78it/s]


Train 2, loss: 3.003427


100%|██████████| 155/155 [00:04<00:00, 37.10it/s]


Test 2, loss: 2.818527, test acc: 0.364668, test avg acc: 0.256000, best test acc: 0.364668


100%|██████████| 410/410 [01:48<00:00,  3.79it/s]


Train 3, loss: 2.927179


100%|██████████| 155/155 [00:03<00:00, 41.31it/s]


Test 3, loss: 2.664188, test acc: 0.426661, test avg acc: 0.306308, best test acc: 0.426661


100%|██████████| 410/410 [01:48<00:00,  3.79it/s]


Train 4, loss: 2.883504


100%|██████████| 155/155 [00:03<00:00, 42.03it/s]


Test 4, loss: 2.625233, test acc: 0.487844, test avg acc: 0.348337, best test acc: 0.487844


100%|██████████| 410/410 [01:48<00:00,  3.79it/s]


Train 5, loss: 2.825877


100%|██████████| 155/155 [00:03<00:00, 38.95it/s]


Test 5, loss: 2.501624, test acc: 0.491896, test avg acc: 0.366285, best test acc: 0.491896


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 6, loss: 2.792876


100%|██████████| 155/155 [00:03<00:00, 44.67it/s]


Test 6, loss: 2.479428, test acc: 0.499595, test avg acc: 0.360058, best test acc: 0.499595


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 7, loss: 2.740571


100%|██████████| 155/155 [00:04<00:00, 38.24it/s]


Test 7, loss: 2.369575, test acc: 0.546596, test avg acc: 0.397035, best test acc: 0.546596


100%|██████████| 410/410 [01:49<00:00,  3.76it/s]


Train 8, loss: 2.708007


100%|██████████| 155/155 [00:03<00:00, 44.76it/s]


Test 8, loss: 2.276463, test acc: 0.648703, test avg acc: 0.521837, best test acc: 0.648703


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 9, loss: 2.675695


100%|██████████| 155/155 [00:04<00:00, 38.27it/s]


Test 9, loss: 2.403428, test acc: 0.572528, test avg acc: 0.460029, best test acc: 0.648703


100%|██████████| 410/410 [01:48<00:00,  3.78it/s]


Train 10, loss: 2.625263


100%|██████████| 155/155 [00:03<00:00, 43.85it/s]


Test 10, loss: 2.283746, test acc: 0.683549, test avg acc: 0.574808, best test acc: 0.683549


100%|██████████| 410/410 [01:48<00:00,  3.79it/s]


Train 11, loss: 2.611473


100%|██████████| 155/155 [00:03<00:00, 43.12it/s]


Test 11, loss: 2.321327, test acc: 0.646272, test avg acc: 0.516256, best test acc: 0.683549


100%|██████████| 410/410 [01:48<00:00,  3.78it/s]


Train 12, loss: 2.581792


100%|██████████| 155/155 [00:03<00:00, 41.86it/s]


Test 12, loss: 2.309464, test acc: 0.656807, test avg acc: 0.534169, best test acc: 0.683549


100%|██████████| 410/410 [01:49<00:00,  3.75it/s]


Train 13, loss: 2.572391


100%|██████████| 155/155 [00:03<00:00, 39.51it/s]


Test 13, loss: 2.318159, test acc: 0.658023, test avg acc: 0.551750, best test acc: 0.683549


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 14, loss: 2.542800


100%|██████████| 155/155 [00:03<00:00, 43.27it/s]


Test 14, loss: 2.254232, test acc: 0.694895, test avg acc: 0.572128, best test acc: 0.694895


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 15, loss: 2.551684


100%|██████████| 155/155 [00:03<00:00, 40.50it/s]


Test 15, loss: 2.298675, test acc: 0.674635, test avg acc: 0.565919, best test acc: 0.694895


100%|██████████| 410/410 [01:48<00:00,  3.79it/s]


Train 16, loss: 2.508439


100%|██████████| 155/155 [00:03<00:00, 42.38it/s]


Test 16, loss: 2.249105, test acc: 0.708671, test avg acc: 0.592640, best test acc: 0.708671


100%|██████████| 410/410 [01:49<00:00,  3.76it/s]


Train 17, loss: 2.492547


100%|██████████| 155/155 [00:03<00:00, 41.05it/s]


Test 17, loss: 2.350362, test acc: 0.670178, test avg acc: 0.547669, best test acc: 0.708671


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 18, loss: 2.478147


100%|██████████| 155/155 [00:03<00:00, 41.27it/s]


Test 18, loss: 2.194626, test acc: 0.729741, test avg acc: 0.594192, best test acc: 0.729741


100%|██████████| 410/410 [01:48<00:00,  3.78it/s]


Train 19, loss: 2.454325


100%|██████████| 155/155 [00:03<00:00, 43.63it/s]


Test 19, loss: 2.246939, test acc: 0.714344, test avg acc: 0.607459, best test acc: 0.729741


100%|██████████| 410/410 [01:48<00:00,  3.77it/s]


Train 20, loss: 2.464678


100%|██████████| 155/155 [00:03<00:00, 43.45it/s]


Test 20, loss: 2.344118, test acc: 0.741086, test avg acc: 0.625965, best test acc: 0.741086


100%|██████████| 410/410 [01:49<00:00,  3.74it/s]


Train 21, loss: 2.417540


100%|██████████| 155/155 [00:03<00:00, 38.83it/s]


Test 21, loss: 2.331160, test acc: 0.730146, test avg acc: 0.616855, best test acc: 0.741086


100%|██████████| 410/410 [01:50<00:00,  3.71it/s]


Train 22, loss: 2.403279


100%|██████████| 155/155 [00:03<00:00, 42.33it/s]


Test 22, loss: 2.134927, test acc: 0.754862, test avg acc: 0.645884, best test acc: 0.754862


100%|██████████| 410/410 [01:49<00:00,  3.74it/s]


Train 23, loss: 2.403116


100%|██████████| 155/155 [00:03<00:00, 41.40it/s]


Test 23, loss: 2.380321, test acc: 0.758509, test avg acc: 0.668924, best test acc: 0.758509


100%|██████████| 410/410 [01:49<00:00,  3.75it/s]


Train 24, loss: 2.397855


100%|██████████| 155/155 [00:03<00:00, 41.15it/s]


Test 24, loss: 2.356671, test acc: 0.747974, test avg acc: 0.656035, best test acc: 0.758509


100%|██████████| 410/410 [01:49<00:00,  3.76it/s]


Train 25, loss: 2.346200


100%|██████████| 155/155 [00:03<00:00, 42.76it/s]


Test 25, loss: 2.257543, test acc: 0.778363, test avg acc: 0.669326, best test acc: 0.778363


100%|██████████| 410/410 [01:49<00:00,  3.75it/s]


Train 26, loss: 2.359623


100%|██████████| 155/155 [00:03<00:00, 40.86it/s]


Test 26, loss: 2.439531, test acc: 0.762561, test avg acc: 0.651070, best test acc: 0.778363


100%|██████████| 410/410 [01:48<00:00,  3.76it/s]


Train 27, loss: 2.349003


100%|██████████| 155/155 [00:03<00:00, 39.13it/s]


Test 27, loss: 2.218172, test acc: 0.779984, test avg acc: 0.674948, best test acc: 0.779984


100%|██████████| 410/410 [01:49<00:00,  3.75it/s]


Train 28, loss: 2.323850


100%|██████████| 155/155 [00:03<00:00, 45.10it/s]


Test 28, loss: 2.202474, test acc: 0.765397, test avg acc: 0.657628, best test acc: 0.779984


100%|██████████| 410/410 [01:48<00:00,  3.80it/s]


Train 29, loss: 2.310728


100%|██████████| 155/155 [00:03<00:00, 46.07it/s]


Test 29, loss: 2.130177, test acc: 0.781199, test avg acc: 0.681610, best test acc: 0.781199


  0%|          | 0/410 [00:00<?, ?it/s]

count: 4
saving data in epoch 30 for batch count 4
torch.Size([24, 1024, 3])


  0%|          | 1/410 [00:00<06:19,  1.08it/s]

count: 3
saving data in epoch 30 for batch count 3
torch.Size([24, 1024, 3])


  0%|          | 2/410 [00:01<03:39,  1.85it/s]

count: 2
saving data in epoch 30 for batch count 2
torch.Size([24, 1024, 3])


  1%|          | 3/410 [00:01<02:48,  2.41it/s]

count: 1
saving data in epoch 30 for batch count 1
torch.Size([24, 1024, 3])


  1%|          | 4/410 [00:01<02:22,  2.84it/s]

count: 0
saving data in epoch 30 for batch count 0
torch.Size([24, 1024, 3])


 56%|█████▋    | 231/410 [01:01<00:47,  3.75it/s]


KeyboardInterrupt: 

In [42]:
# Remove the arrays train() dumped for visualisation (the files `rm Data_for_viz/*` deleted).
# Pure Python: portable, and a silent no-op when the folder is missing or empty.
_viz_dir = "Data_for_viz"
if os.path.isdir(_viz_dir):
    for _entry in os.listdir(_viz_dir):
        _entry_path = os.path.join(_viz_dir, _entry)
        # Like the shell glob `*`: regular files only, dotfiles skipped.
        if os.path.isfile(_entry_path) and not _entry.startswith('.'):
            os.remove(_entry_path)