In [1]:
!nvidia-smi

Tue Dec 28 12:54:23 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:01:00.0 Off |                    0 |
| N/A   48C    P0    63W / 275W |  20490MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:47:00.0 Off |                    0 |
| N/A   58C    P0   268W / 275W |  26102MiB / 40536MiB |     85%      Default |
|       

In [3]:
import random
import os
import os.path
import numpy as np
import math
from tqdm import tqdm
import random
import copy
import sys

import torch as t
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F 

import cv2
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter



In [6]:
device = "cuda:2"
device1 = "cuda:2"
device2 = "cuda:2"

train_path = "Data/train_data_tcia.pth"
test_path = "Data/test_data_tcia.pth"

file_name = "anatomy_tcia_benchmark.txt"

#epoch150_path_name = "anatomy_tcia_150.pth"
#epoch150_50_path_name = "anatomy_tcia_150+50.pth"

weight_path = "Weights/anatomy_tcia_150+50.pth"
#test_path = "Data/test_data_tcia.pth"

#anatomy_epoch1 = 150
#anatomy_epoch2 = 50
#print("total epochs", anatomy_epoch1, anatomy_epoch2)


In [7]:
class HaN_Dataset(Dataset):
    
    def __init__(self, root_dir=None, transform=False, alpha=1000, sigma=30, alpha_affine=0.04):
        super().__init__()
        self.path = root_dir
        self.datas = t.load(self.path)
        
        self.transform = transform
        self.alpha = alpha
        self.sigma = sigma
        self.alpha_affine = alpha_affine
    
    def __getitem__(self, index):
        data = self.datas[index]
        img = data['img'].numpy().astype(np.float32)
        
        if not self.transform:
            masklst = []
            for mask in data['mask']:
                if mask is None:
                    mask = np.zeros((1,img.shape[0],img.shape[1],img.shape[2])).astype(np.uint8)
                masklst.append(mask.astype(np.uint8).reshape((1,img.shape[0],img.shape[1],img.shape[2]))) 
            mask0 = np.zeros_like(masklst[0]).astype(np.uint8)
            for mask in masklst:
                mask0 = np.logical_or(mask0, mask).astype(np.uint8)
            mask0 = 1 - mask0
            return t.from_numpy(img.reshape((1, img.shape[0], img.shape[1], img.shape[2]))), t.from_numpy(np.concatenate([mask0]+masklst, axis=0)), True
        
        im_merge = np.concatenate([img[...,None]]+[mask.astype(np.float32)[...,None] for mask in data['mask']], axis=3)
        # Apply transformation on image
        im_merge_t, new_img = self.elastic_transform3Dv2(im_merge,self.alpha,self.sigma,min(im_merge.shape[1:-1])*self.alpha_affine)
        # Split image and mask ::2, ::2, ::2
        im_t = im_merge_t[...,0]
        im_mask_t = im_merge_t[..., 1:].astype(np.uint8).transpose(3, 0, 1, 2)
        mask0 = np.zeros_like(im_mask_t[0, :, :, :]).reshape((1,)+im_mask_t.shape[1:]).astype(np.uint8)
        im_mask_t_lst = []
        flagvect = np.ones((8,), np.float32)
        retflag = True
        for i in range(7):
            im_mask_t_lst.append(im_mask_t[i,:,:,:].reshape((1,)+im_mask_t.shape[1:]))
            if im_mask_t[i,:,:,:].max() != 1: 
                retflag = False
                flagvect[i+1] = 0
            mask0 = np.logical_or(mask0, im_mask_t[i,:,:,:]).astype(np.uint8)
        if not retflag: flagvect[0] = 0
        mask0 = 1 - mask0
        return t.from_numpy(im_t.reshape((1,)+im_t.shape[:3])), t.from_numpy(np.concatenate([mask0]+im_mask_t_lst, axis=0)), flagvect
        
    def __len__(self):
        return len(self.datas)
    
    def elastic_transform3Dv2(self, image, alpha, sigma, alpha_affine, random_state=None):
        """Elastic deformation of images as described in [Simard2003]_ (with modifications).
        .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
             Convolutional Neural Networks applied to Visual Document Analysis", in
             Proc. of the International Conference on Document Analysis and
             Recognition, 2003.
         Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5
         From https://www.kaggle.com/bguberfain/elastic-transform-for-data-augmentation
        """
        # affine and deformation must be slice by slice and fixed for slices
        if random_state is None:
            random_state = np.random.RandomState(None)
        shape = image.shape # image is contatenated, the first channel [:,:,:,0] is the image, the second channel 
        # [:,:,:,1] is the mask. The two channel are under the same tranformation.
        shape_size = shape[:-1] # z y x
        # Random affine
        shape_size_aff = shape[1:-1] # y x
        center_square = np.float32(shape_size_aff) // 2
        square_size = min(shape_size_aff) // 3
        pts1 = np.float32([center_square + square_size, [center_square[0]+square_size, center_square[1]-square_size], center_square - square_size])
        pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)
        new_img = np.zeros_like(image)
        for i in range(shape[0]):
            new_img[i,:,:,0] = cv2.warpAffine(image[i,:,:,0], M, shape_size_aff[::-1], borderMode=cv2.BORDER_CONSTANT, borderValue=0.)
            for j in range(1, 8):
                new_img[i,:,:,j] = cv2.warpAffine(image[i,:,:,j], M, shape_size_aff[::-1], flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_TRANSPARENT, borderValue=0)
        dx = gaussian_filter((random_state.rand(*shape[1:-1]) * 2 - 1), sigma) * alpha
        dy = gaussian_filter((random_state.rand(*shape[1:-1]) * 2 - 1), sigma) * alpha
        x, y = np.meshgrid(np.arange(shape_size_aff[1]), np.arange(shape_size_aff[0]))
        indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
        new_img2 = np.zeros_like(image)
        for i in range(shape[0]):
            new_img2[i,:,:,0] = map_coordinates(new_img[i,:,:,0], indices, order=1, mode='constant').reshape(shape[1:-1])
            for j in range(1, 8):
                new_img2[i,:,:,j] = map_coordinates(new_img[i,:,:,j], indices, order=0, mode='constant').reshape(shape[1:-1])
        return np.array(new_img2), new_img
# %%


traindataset = HaN_Dataset(train_path, transform=True)
traindataloader = DataLoader(traindataset, batch_size=1, shuffle=True)
testdataset = HaN_Dataset(test_path, transform=False)
testdataloader = DataLoader(testdataset, batch_size=1)
print(len(testdataloader))


20


In [8]:
#network
from torch import nn
import torch.nn.functional as F
from scipy.spatial.distance import dice
def conv3x3x3(in_planes, out_planes, stride=1):
    "3x3x3 convolution with padding"
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock3D(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock3D, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        if inplanes != planes:
            self.downsample = nn.Sequential(nn.Conv3d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                                            nn.BatchNorm3d(planes))
        else:
            self.downsample = lambda x: x
        self.stride = stride       
    def forward(self, x):
#         print(x.size())
        residual = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
#         print(x.size(), residual.size(), out.size())
        out += residual
        out = self.relu(out)
        return out
def Deconv3x3x3(in_planes, out_planes, stride=2):
    "3x3x3 deconvolution with padding"
    return nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=stride)

class SELayer3D(nn.Module):
    def __init__(self, channel, reduction=15):
        super(SELayer3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction),
                nn.LeakyReLU(inplace=True),
                nn.Linear(channel // reduction, channel),
                nn.Sigmoid())
    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y
class SEBasicBlock3D(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=15):
        super(SEBasicBlock3D, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm3d(planes)
        self.se = SELayer3D(planes, reduction)
        if inplanes != planes:
            self.downsample = nn.Sequential(nn.Conv3d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                                            nn.BatchNorm3d(planes))
        else:
            self.downsample = lambda x: x
        self.stride = stride
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)
#         if self.downsample is not None:
        residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class UpSEBasicBlock3D(nn.Module):
    def __init__(self, inplanes1, inplanes2, planes, stride=1, downsample=None, reduction=16):
        super(UpSEBasicBlock3D, self).__init__()
        inplanes3 = inplanes1 + inplanes2
        if stride == 2:
            self.deconv1 = Deconv3x3x3(inplanes1, inplanes1//2)
            inplanes3 = inplanes1 // 2 + inplanes2
        self.stride = stride
        # self.conv1x1x1 = nn.Conv3d(inplanes2, planes, kernel_size=1, stride=1)#, padding=1)
        self.conv1 = conv3x3x3(inplanes3, planes)#, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.se = SELayer3D(planes, reduction)
        if inplanes3 != planes:
            self.downsample = nn.Sequential(nn.Conv3d(inplanes3, planes, kernel_size=1, stride=stride, bias=False),
                                            nn.BatchNorm3d(planes))
        else:
            self.downsample = lambda x: x
        self.stride = stride
    def forward(self, x1, x2):
#         print(x1.size(), x2.size())
        if self.stride == 2: x1 = self.deconv1(x1)
        # x2 = self.conv1x1x1(x2)
        #print(x1.size(), x2.size())
        out = t.cat([x1, x2], dim=1) #x1 + x2
        residual = self.downsample(out)
        #print(residual.size(), x1.size(), x2.size())
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)
        #print(out.size(), residual.size())
        out += residual
        out = self.relu(out)
        return out
class UpBasicBlock3D(nn.Module):
    def __init__(self, inplanes1, inplanes2, planes, stride=2):
        super(UpBasicBlock3D, self).__init__()
        inplanes3 = inplanes1 + inplanes2
        if stride == 2:
            self.deconv1 = Deconv3x3x3(inplanes1, inplanes1//2)
            inplanes3 = inplanes1//2 + inplanes2
        self.stride = stride
        # elif inplanes1 != planes:
            # self.deconv1 = nn.Conv3d(inplanes1, planes, kernel_size=1, stride=1)
        # self.conv1x1x1 = nn.Conv3d(inplanes2, planes, kernel_size=1, stride=1)#, padding=1)
        self.conv1 = conv3x3x3(inplanes3, planes)#, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        if inplanes3 != planes:
            self.downsample = nn.Sequential(nn.Conv3d(inplanes3, planes, kernel_size=3, stride=1, padding=1, bias=False),
                                            nn.BatchNorm3d(planes))
        else:
            self.downsample = lambda x: x
        self.stride = stride
    def forward(self, x1, x2):
#         print(x1.size(), x2.size())
        if self.stride == 2: x1 = self.deconv1(x1)
        #print(self.stride, x1.size(), x2.size())
        out = t.cat([x1, x2], dim=1)
        residual = self.downsample(out)
        #print(out.size(), residual.size())
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)
        return out
class ResNetUNET3D(nn.Module):
    def __init__(self, block, upblock, upblock1, n_size, num_classes=2, in_channel=1): # BasicBlock, 3
        super(ResNetUNET3D, self).__init__()
        self.inplane = 28
        self.conv1 = nn.Conv3d(in_channel, self.inplane, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(self.inplane)
        self.relu = nn.LeakyReLU(inplace=True)
        self.layer1 = self._make_layer(block, 30, blocks=n_size, stride=1)
        self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=1)
        self.layer3 = self._make_layer(block, 34, blocks=n_size, stride=1)
        self.layer4 = upblock(34, 32, 32, stride=1)
        self.inplane = 32
        self.layer5 = self._make_layer(block, 32, blocks=n_size-1, stride=1)
        self.layer6 = upblock(32, 30, 30, stride=1)
        self.inplane = 30
        self.layer7 = self._make_layer(block, 30, blocks=n_size-1, stride=1)
        self.layer8 = upblock(30, 28, 28, stride=1)
        self.inplane = 28
        self.layer9 = self._make_layer(block, 28, blocks=n_size-1, stride=1)
        self.inplane = 28
        self.layer10 = upblock1(28, 1, 14, stride=2)
        self.layer11 = nn.Sequential(#nn.Conv3d(16, 14, kernel_size=3, stride=1, padding=1, bias=True),
                                     #nn.ReLU(inplace=True),
                                     nn.Conv3d(14, num_classes, kernel_size=3, stride=1, padding=1, bias=True))
#         self.outconv = nn.ConvTranspose3d(self.inplane, num_classes, 2, stride=2)
        self.initialize()
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose3d):
                nn.init.kaiming_normal_(m.weight)
    def _make_layer(self, block, planes, blocks, stride):
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inplane, planes, stride))
            self.inplane = planes
        return nn.Sequential(*layers)
    def forward(self, x0):
        x = self.conv1(x0) # 16 1/2 
        x = self.bn1(x)
        x1 = self.relu(x)

        x2 = self.layer1(x1) # 16 1/4 16 1/4 res 16 1/4 - 16 1/4 16 1/4 res 16 1/4 - 16 1/4 16 1/4 res 16 1/4
        x3 = self.layer2(x2) # 32 1/8 32 1/8 res 32 1/8 - 32 1/8 32 1/8 res 32 1/8 - 32 1/8 32 1/8 res 32 1/8
        x4 = self.layer3(x3) # 64 1/16 64 1/16 res 64 1/16 - 64 1/16 64 1/16 res 64 1/16 - 64 1/16 64 1/16 res 64 1/16
#         print('x4', x4.size())
        x5 = self.layer4(x4, x3) # 16 1/8 48 1/8 32 1/8 32 1/8 res 32 1/8 - 32 1/8 32 1/8 res 32 1/8 - 32 1/8 32 1/8 res 32 1/8
        x5 = self.layer5(x5)
        x6 = self.layer6(x5, x2) # 8 1/4 24 1/4 16 1/4 16 1/4 res 16 1/4 - 16 1/4 16 1/4 res 16 1/4 - 16 1/4 16 1/4 res 16 1/4
        x6 = self.layer7(x6)
        x7 = self.layer8(x6, x1) # 4 1/2 20 1/2 16 1/2 16 1/2 res 16 1/2 - 16 1/2 16 1/2 res 16 1/2 - 16 1/2 16 1/2 res 16 1/2
        x7 = self.layer9(x7)
        x8 = self.layer10(x7, x0)
        x9 = self.layer11(x8)
#         print(x0.size(), x.size(), x1.size(), x2.size(), x3.size(), x4.size(), x5.size(), x6.size(), \
#               x7.size(), x8.size(), x9.size())
        return F.softmax(x9, dim=1)

In [9]:
#Loss
def tversky_loss_wmask(y_pred, y_true, flagvec):
    alpha = 0.5
    beta  = 0.5
    ones = t.ones_like(y_pred) #K.ones(K.shape(y_true))
#     print(type(ones.data), type(y_true.data), type(y_pred.data), ones.size(), y_pred.size())
    p0 = y_pred      # proba that voxels are class i
    p1 = ones-y_pred # proba that voxels are not class i
    g0 = y_true.type(t.cuda.FloatTensor)
    g1 = ones-g0
    num = t.sum(t.sum(t.sum(t.sum(p0*g0, 4),3),2),0) #(0,2,3,4)) #K.sum(p0*g0, (0,1,2,3))
    den = num + alpha*t.sum(t.sum(t.sum(t.sum(p0*g1,4),3),2),0) + beta*t.sum(t.sum(t.sum(t.sum(p1*g0,4),3),2),0) #(0,2,3,4))

    T = t.sum((num * flagvec.to(device1))/(den+1e-5))

#     Ncl = y_pred.size(1)*1.0
#     print(Ncl, T)
    return t.sum(flagvec.to(device1))-T


def focal(y_pred, y_true, flagvec):
    retv = - t.mean(t.mean(t.mean(t.mean(t.log(t.clamp(y_pred,1e-6,1))*y_true.type(t.cuda.FloatTensor)*t.pow(1-y_pred,2),4),3),2),0)\
        * flagvec.to(device1)
    return t.sum(retv)

In [10]:
#Performance metric
def caldice(y_pred, y_true):
#     print(y_pred.sum(), y_true.sum())
    y_pred = y_pred.data.cpu().numpy().transpose(1,0,2,3,4) # inference should be arg max
    y_pred = np.argmax(y_pred, axis=0).squeeze() # z y x
    y_true = y_true.data.numpy().transpose(1,0,2,3,4).squeeze() # .cpu()
    avgdice = []
    y_pred_1 = y_pred==1
    y_true_1 = y_true[1,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==2
    y_true_1 = y_true[2,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==3
    y_true_1 = y_true[3,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==4
    y_true_1 = y_true[4,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==5
    y_true_1 = y_true[5,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==6
    y_true_1 = y_true[6,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==7
    y_true_1 = y_true[7,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    """
    y_pred_1 = y_pred==8
    y_true_1 = y_true[8,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    
    y_pred_1 = y_pred==9
    y_true_1 = y_true[9,:,:,:]
    if y_pred_1.sum() + y_true_1.sum() == 0: avgdice.append(-1)
    else: avgdice.append(2.*(np.logical_and(y_pred_1, y_true_1).sum()) / (1.0*(y_pred_1.sum() + y_true_1.sum())))
    """
    for dice in avgdice: 
        if dice != -1:
            assert 0 <= dice <= 1
    return avgdice

In [13]:
#Anatomy training with tcia
model = ResNetUNET3D(SEBasicBlock3D, UpSEBasicBlock3D, UpBasicBlock3D, 2, num_classes=7+1, in_channel=1).to(device1) 
lossweight = np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0], np.float32)
# pretraind_dict = t.load('./model/unet10pool3e2e_seres18_conc_pet_wmask_2_rmsp_1')["weight"]
# model_dict = model.state_dict()
# pretraind_dict = {k: v for k, v in pretraind_dict.items() if k in model_dict}
# model_dict.update(pretraind_dict)
# model.load_state_dict(pretraind_dict)
savename = 'AnatomyTciamodel/AnatomyTcia_'


# In[5]:

epoch=50
optimizer = t.optim.RMSprop(model.parameters(),lr = 5e-4)
maxloss = [0 for _ in range(7)]
for epoch in range(50):
    tq = tqdm(traindataloader, desc='loss', leave=True)
    trainloss = 0
    for x_train, y_train, flagvec in tq:
        x_train = t.autograd.Variable(x_train.to(device1))
        y_train = t.autograd.Variable(y_train.to(device1))
        optimizer.zero_grad()
        o = model(x_train)
        loss = tversky_loss_wmask(o, y_train, flagvec*t.from_numpy(lossweight))
        loss.backward()
        optimizer.step()
        tq.set_description("epoch %i loss %f" % (epoch, loss.item()))
        tq.refresh() # to show immediately the update
        trainloss += loss.item()
        del loss, x_train, y_train, o
    testtq = tqdm(testdataloader, desc='test loss', leave=True)
    testloss = [0 for _ in range(7)]
    for x_test, y_test, _ in testtq:
#         print(x_test.numpy().shape)
        with t.no_grad():
            x_test = t.autograd.Variable(x_test.to(device1))
#             y_test = t.autograd.Variable(y_test.to(device1))
        o = model(x_test)
        loss = caldice(o, y_test)
        testtq.set_description("epoch %i test loss %f" % (epoch, sum(loss)/7))
        testtq.refresh() # to show immediately the update
        testloss = [l+tl for l,tl in zip(loss, testloss)]
        del x_test, y_test, o
    testloss = [l / len(testtq) for l in testloss]
    for cls in range(7):
        if maxloss[cls] < testloss[cls]:
            maxloss[cls] = testloss[cls]
            state = {"epoch": epoch, "weight": model.state_dict()}
            t.save(state, savename+str(cls+1))
#             model.load_state_dict(t.load(savename)["weight"])
#             t.save(model, savename+str(cls+1))
    print('epoch %i TRAIN loss %.4f' % (epoch, trainloss/len(tq)))
    print('test loss %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f' % tuple(testloss))
    print('best test loss %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f' % tuple(maxloss))
   

epoch 0 loss 7.984834: 100%|██████████████████| 211/211 [10:09<00:00,  2.89s/it]
epoch 0 test loss 0.249903: 100%|███████████████| 20/20 [00:10<00:00,  1.91it/s]


epoch 0 TRAIN loss 7.4473
test loss 0.1995, 0.0057, 0.7975, 0.0000, 0.0000, 0.2612, 0.4298
best test loss 0.1995, 0.0057, 0.7975, 0.0000, 0.0000, 0.2612, 0.4298


epoch 1 loss 0.292073: 100%|██████████████████| 211/211 [09:57<00:00,  2.83s/it]
epoch 1 test loss 0.405132: 100%|███████████████| 20/20 [00:10<00:00,  1.97it/s]


epoch 1 TRAIN loss 5.8645
test loss 0.3773, 0.1835, 0.8378, 0.0000, 0.0000, 0.6491, 0.6957
best test loss 0.3773, 0.1835, 0.8378, 0.0000, 0.0000, 0.6491, 0.6957


epoch 2 loss 5.848918: 100%|██████████████████| 211/211 [09:52<00:00,  2.81s/it]
epoch 2 test loss 0.490518: 100%|███████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 2 TRAIN loss 5.2830
test loss 0.4987, 0.4834, 0.8612, 0.0358, 0.0000, 0.7222, 0.7546
best test loss 0.4987, 0.4834, 0.8612, 0.0358, 0.0000, 0.7222, 0.7546


epoch 3 loss 5.383348: 100%|██████████████████| 211/211 [09:56<00:00,  2.83s/it]
epoch 3 test loss 0.547645: 100%|███████████████| 20/20 [00:10<00:00,  1.92it/s]


epoch 3 TRAIN loss 4.8441
test loss 0.5073, 0.4741, 0.8257, 0.4065, 0.0000, 0.7759, 0.7709
best test loss 0.5073, 0.4834, 0.8612, 0.4065, 0.0000, 0.7759, 0.7709


epoch 4 loss 5.565170: 100%|██████████████████| 211/211 [10:06<00:00,  2.87s/it]
epoch 4 test loss 0.605635: 100%|███████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 4 TRAIN loss 4.2048
test loss 0.4881, 0.3892, 0.8514, 0.5633, 0.0000, 0.7901, 0.7626
best test loss 0.5073, 0.4834, 0.8612, 0.5633, 0.0000, 0.7901, 0.7709


epoch 5 loss 4.206728: 100%|██████████████████| 211/211 [10:00<00:00,  2.84s/it]
epoch 5 test loss 0.643972: 100%|███████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 5 TRAIN loss 3.9526
test loss 0.5202, 0.5536, 0.8731, 0.6662, 0.1170, 0.8023, 0.8033
best test loss 0.5202, 0.5536, 0.8731, 0.6662, 0.1170, 0.8023, 0.8033


epoch 6 loss 3.612428: 100%|██████████████████| 211/211 [09:58<00:00,  2.84s/it]
epoch 6 test loss 0.688150: 100%|███████████████| 20/20 [00:10<00:00,  1.98it/s]


epoch 6 TRAIN loss 3.2020
test loss 0.5353, 0.5288, 0.8715, 0.6168, 0.6166, 0.7734, 0.7189
best test loss 0.5353, 0.5536, 0.8731, 0.6662, 0.6166, 0.8023, 0.8033


epoch 7 loss 3.935063: 100%|██████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 7 test loss 0.752614: 100%|███████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 7 TRAIN loss 2.7567
test loss 0.5368, 0.5804, 0.8751, 0.7138, 0.6395, 0.8073, 0.7959
best test loss 0.5368, 0.5804, 0.8751, 0.7138, 0.6395, 0.8073, 0.8033


epoch 8 loss 2.903270: 100%|██████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 8 test loss 0.742846: 100%|███████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 8 TRAIN loss 2.7081
test loss 0.5432, 0.5154, 0.8844, 0.6947, 0.6836, 0.7919, 0.7772
best test loss 0.5432, 0.5804, 0.8844, 0.7138, 0.6836, 0.8073, 0.8033


epoch 9 loss 1.199904: 100%|██████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 9 test loss 0.720515: 100%|███████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 9 TRAIN loss 2.6414
test loss 0.5433, 0.5646, 0.8795, 0.7136, 0.7074, 0.7856, 0.7545
best test loss 0.5433, 0.5804, 0.8844, 0.7138, 0.7074, 0.8073, 0.8033


epoch 10 loss 3.526790: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 10 test loss 0.751851: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 10 TRAIN loss 2.5636
test loss 0.5343, 0.5569, 0.9045, 0.7120, 0.7272, 0.8143, 0.8270
best test loss 0.5433, 0.5804, 0.9045, 0.7138, 0.7272, 0.8143, 0.8270


epoch 11 loss 0.000000: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 11 test loss 0.751360: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 11 TRAIN loss 2.4702
test loss 0.5555, 0.5879, 0.8993, 0.7072, 0.6874, 0.7997, 0.8035
best test loss 0.5555, 0.5879, 0.9045, 0.7138, 0.7272, 0.8143, 0.8270


epoch 12 loss 3.890705: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 12 test loss 0.775343: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 12 TRAIN loss 2.4660
test loss 0.5606, 0.5926, 0.8961, 0.7286, 0.7191, 0.8052, 0.8259
best test loss 0.5606, 0.5926, 0.9045, 0.7286, 0.7272, 0.8143, 0.8270


epoch 13 loss 0.000000: 100%|█████████████████| 211/211 [09:42<00:00,  2.76s/it]
epoch 13 test loss 0.783261: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 13 TRAIN loss 2.3924
test loss 0.5690, 0.5770, 0.9065, 0.7360, 0.7370, 0.8250, 0.8372
best test loss 0.5690, 0.5926, 0.9065, 0.7360, 0.7370, 0.8250, 0.8372


epoch 14 loss 2.443136: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 14 test loss 0.773287: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 14 TRAIN loss 2.3779
test loss 0.5715, 0.6215, 0.9126, 0.7228, 0.7355, 0.8206, 0.8372
best test loss 0.5715, 0.6215, 0.9126, 0.7360, 0.7370, 0.8250, 0.8372


epoch 15 loss 2.657015: 100%|█████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 15 test loss 0.764545: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 15 TRAIN loss 2.3333
test loss 0.5686, 0.6080, 0.9118, 0.7231, 0.7395, 0.8226, 0.8228
best test loss 0.5715, 0.6215, 0.9126, 0.7360, 0.7395, 0.8250, 0.8372


epoch 16 loss 2.188157: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 16 test loss 0.765111: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 16 TRAIN loss 2.3275
test loss 0.5712, 0.5931, 0.9113, 0.7174, 0.7379, 0.8300, 0.8405
best test loss 0.5715, 0.6215, 0.9126, 0.7360, 0.7395, 0.8300, 0.8405


epoch 17 loss 2.687555: 100%|█████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 17 test loss 0.747076: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 17 TRAIN loss 2.3283
test loss 0.5724, 0.6224, 0.8984, 0.6333, 0.6108, 0.8102, 0.8330
best test loss 0.5724, 0.6224, 0.9126, 0.7360, 0.7395, 0.8300, 0.8405


epoch 18 loss 2.256000: 100%|█████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 18 test loss 0.773976: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 18 TRAIN loss 2.2497
test loss 0.5751, 0.6149, 0.9043, 0.7315, 0.7472, 0.8321, 0.8375
best test loss 0.5751, 0.6224, 0.9126, 0.7360, 0.7472, 0.8321, 0.8405


epoch 19 loss 2.193202: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 19 test loss 0.776798: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 19 TRAIN loss 2.2665
test loss 0.5751, 0.6239, 0.9086, 0.7245, 0.7355, 0.8178, 0.8390
best test loss 0.5751, 0.6239, 0.9126, 0.7360, 0.7472, 0.8321, 0.8405


epoch 20 loss 2.293175: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 20 test loss 0.789531: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 20 TRAIN loss 2.2085
test loss 0.5752, 0.5727, 0.9062, 0.7132, 0.7394, 0.8328, 0.8455
best test loss 0.5752, 0.6239, 0.9126, 0.7360, 0.7472, 0.8328, 0.8455


epoch 21 loss 2.745738: 100%|█████████████████| 211/211 [09:45<00:00,  2.78s/it]
epoch 21 test loss 0.798468: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 21 TRAIN loss 2.2341
test loss 0.5803, 0.6236, 0.9135, 0.7190, 0.7520, 0.8372, 0.8427
best test loss 0.5803, 0.6239, 0.9135, 0.7360, 0.7520, 0.8372, 0.8455


epoch 22 loss 2.921762: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 22 test loss 0.758040: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 22 TRAIN loss 2.2051
test loss 0.5656, 0.4578, 0.9065, 0.7175, 0.6910, 0.8352, 0.8478
best test loss 0.5803, 0.6239, 0.9135, 0.7360, 0.7520, 0.8372, 0.8478


epoch 23 loss 2.737626: 100%|█████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 23 test loss 0.767900: 100%|██████████████| 20/20 [00:10<00:00,  1.91it/s]


epoch 23 TRAIN loss 2.1924
test loss 0.5815, 0.6082, 0.9116, 0.7302, 0.7281, 0.8285, 0.8370
best test loss 0.5815, 0.6239, 0.9135, 0.7360, 0.7520, 0.8372, 0.8478


epoch 24 loss 3.741282: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 24 test loss 0.782621: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 24 TRAIN loss 2.1326
test loss 0.5786, 0.5968, 0.9161, 0.7320, 0.7472, 0.8385, 0.8568
best test loss 0.5815, 0.6239, 0.9161, 0.7360, 0.7520, 0.8385, 0.8568


epoch 25 loss 3.659197: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 25 test loss 0.758596: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 25 TRAIN loss 2.1257
test loss 0.5778, 0.5182, 0.8988, 0.6973, 0.7117, 0.8134, 0.8391
best test loss 0.5815, 0.6239, 0.9161, 0.7360, 0.7520, 0.8385, 0.8568


epoch 26 loss 1.951533: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 26 test loss 0.796122: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 26 TRAIN loss 2.1264
test loss 0.5789, 0.6198, 0.9208, 0.7383, 0.7474, 0.8374, 0.8459
best test loss 0.5815, 0.6239, 0.9208, 0.7383, 0.7520, 0.8385, 0.8568


epoch 27 loss 1.171332: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 27 test loss 0.775100: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 27 TRAIN loss 2.1450
test loss 0.5848, 0.6345, 0.9098, 0.7317, 0.7260, 0.8297, 0.8488
best test loss 0.5848, 0.6345, 0.9208, 0.7383, 0.7520, 0.8385, 0.8568


epoch 28 loss 2.299464: 100%|█████████████████| 211/211 [09:45<00:00,  2.78s/it]
epoch 28 test loss 0.786360: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 28 TRAIN loss 2.0826
test loss 0.5827, 0.6250, 0.9165, 0.7334, 0.7495, 0.8377, 0.8561
best test loss 0.5848, 0.6345, 0.9208, 0.7383, 0.7520, 0.8385, 0.8568


epoch 29 loss 2.562873: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 29 test loss 0.778038: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 29 TRAIN loss 2.0725
test loss 0.5810, 0.6199, 0.9215, 0.7227, 0.7520, 0.8378, 0.7782
best test loss 0.5848, 0.6345, 0.9215, 0.7383, 0.7520, 0.8385, 0.8568


epoch 30 loss 2.337221: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 30 test loss 0.787444: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 30 TRAIN loss 2.0634
test loss 0.5790, 0.6086, 0.9191, 0.7289, 0.7525, 0.8411, 0.8559
best test loss 0.5848, 0.6345, 0.9215, 0.7383, 0.7525, 0.8411, 0.8568


epoch 31 loss 3.204741: 100%|█████████████████| 211/211 [09:45<00:00,  2.77s/it]
epoch 31 test loss 0.776545: 100%|██████████████| 20/20 [00:10<00:00,  1.92it/s]


epoch 31 TRAIN loss 2.0408
test loss 0.5774, 0.6310, 0.9166, 0.7326, 0.7528, 0.8349, 0.8516
best test loss 0.5848, 0.6345, 0.9215, 0.7383, 0.7528, 0.8411, 0.8568


epoch 32 loss 2.545722: 100%|█████████████████| 211/211 [09:44<00:00,  2.77s/it]
epoch 32 test loss 0.787384: 100%|██████████████| 20/20 [00:10<00:00,  1.96it/s]


epoch 32 TRAIN loss 2.0192
test loss 0.5841, 0.6116, 0.9170, 0.7380, 0.7450, 0.8408, 0.8585
best test loss 0.5848, 0.6345, 0.9215, 0.7383, 0.7528, 0.8411, 0.8585


epoch 33 loss 2.402507: 100%|█████████████████| 211/211 [09:43<00:00,  2.77s/it]
epoch 33 test loss 0.775999: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 33 TRAIN loss 2.0310
test loss 0.5798, 0.6245, 0.9213, 0.7176, 0.7405, 0.8394, 0.8489
best test loss 0.5848, 0.6345, 0.9215, 0.7383, 0.7528, 0.8411, 0.8585


epoch 34 loss 2.248990: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 34 test loss 0.786481: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 34 TRAIN loss 2.0572
test loss 0.5829, 0.5888, 0.9244, 0.7229, 0.7579, 0.8431, 0.8573
best test loss 0.5848, 0.6345, 0.9244, 0.7383, 0.7579, 0.8431, 0.8585


epoch 35 loss 2.248512: 100%|█████████████████| 211/211 [09:41<00:00,  2.75s/it]
epoch 35 test loss 0.772107: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 35 TRAIN loss 2.0402
test loss 0.5847, 0.5901, 0.9183, 0.7384, 0.7552, 0.8311, 0.8507
best test loss 0.5848, 0.6345, 0.9244, 0.7384, 0.7579, 0.8431, 0.8585


epoch 36 loss 2.519341: 100%|█████████████████| 211/211 [09:41<00:00,  2.76s/it]
epoch 36 test loss 0.771159: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 36 TRAIN loss 2.0401
test loss 0.5836, 0.5894, 0.9217, 0.7345, 0.7579, 0.8444, 0.8472
best test loss 0.5848, 0.6345, 0.9244, 0.7384, 0.7579, 0.8444, 0.8585


epoch 37 loss 1.760215: 100%|█████████████████| 211/211 [09:41<00:00,  2.76s/it]
epoch 37 test loss 0.759888: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 37 TRAIN loss 2.0268
test loss 0.5853, 0.5519, 0.9099, 0.7338, 0.7454, 0.8452, 0.8605
best test loss 0.5853, 0.6345, 0.9244, 0.7384, 0.7579, 0.8452, 0.8605


epoch 38 loss 2.631798: 100%|█████████████████| 211/211 [09:39<00:00,  2.75s/it]
epoch 38 test loss 0.783921: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 38 TRAIN loss 1.9918
test loss 0.5779, 0.5922, 0.9208, 0.7221, 0.7496, 0.8308, 0.8457
best test loss 0.5853, 0.6345, 0.9244, 0.7384, 0.7579, 0.8452, 0.8605


epoch 39 loss 2.132298: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 39 test loss 0.782668: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 39 TRAIN loss 2.0033
test loss 0.5778, 0.6081, 0.9226, 0.7214, 0.7607, 0.8446, 0.8613
best test loss 0.5853, 0.6345, 0.9244, 0.7384, 0.7607, 0.8452, 0.8613


epoch 40 loss 1.983316: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 40 test loss 0.770030: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 40 TRAIN loss 1.9674
test loss 0.5819, 0.6161, 0.9154, 0.7094, 0.7479, 0.8396, 0.8623
best test loss 0.5853, 0.6345, 0.9244, 0.7384, 0.7607, 0.8452, 0.8623


epoch 41 loss 1.774611: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 41 test loss 0.789446: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 41 TRAIN loss 1.9812
test loss 0.5842, 0.6116, 0.9248, 0.7332, 0.7505, 0.8402, 0.8528
best test loss 0.5853, 0.6345, 0.9248, 0.7384, 0.7607, 0.8452, 0.8623


epoch 42 loss 2.246955: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 42 test loss 0.771342: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 42 TRAIN loss 1.9269
test loss 0.5837, 0.6218, 0.9258, 0.6953, 0.7540, 0.8387, 0.8507
best test loss 0.5853, 0.6345, 0.9258, 0.7384, 0.7607, 0.8452, 0.8623


epoch 43 loss 2.578094: 100%|█████████████████| 211/211 [09:39<00:00,  2.75s/it]
epoch 43 test loss 0.772983: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 43 TRAIN loss 1.9214
test loss 0.5842, 0.6038, 0.9222, 0.7264, 0.7576, 0.8461, 0.8634
best test loss 0.5853, 0.6345, 0.9258, 0.7384, 0.7607, 0.8461, 0.8634


epoch 44 loss 2.304584: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 44 test loss 0.784981: 100%|██████████████| 20/20 [00:10<00:00,  1.95it/s]


epoch 44 TRAIN loss 1.9275
test loss 0.5899, 0.6131, 0.9262, 0.7362, 0.7513, 0.8478, 0.8627
best test loss 0.5899, 0.6345, 0.9262, 0.7384, 0.7607, 0.8478, 0.8634


epoch 45 loss 2.230659: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 45 test loss 0.773996: 100%|██████████████| 20/20 [00:10<00:00,  1.97it/s]


epoch 45 TRAIN loss 1.9273
test loss 0.5796, 0.6233, 0.9178, 0.7208, 0.7495, 0.8365, 0.8518
best test loss 0.5899, 0.6345, 0.9262, 0.7384, 0.7607, 0.8478, 0.8634


epoch 46 loss 1.990067: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 46 test loss 0.764314: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 46 TRAIN loss 1.9175
test loss 0.5833, 0.6031, 0.9136, 0.7276, 0.7591, 0.8377, 0.8628
best test loss 0.5899, 0.6345, 0.9262, 0.7384, 0.7607, 0.8478, 0.8634


epoch 47 loss 1.624497: 100%|█████████████████| 211/211 [09:40<00:00,  2.75s/it]
epoch 47 test loss 0.787135: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]


epoch 47 TRAIN loss 1.8992
test loss 0.5889, 0.6206, 0.9277, 0.7275, 0.7377, 0.8453, 0.8629
best test loss 0.5899, 0.6345, 0.9277, 0.7384, 0.7607, 0.8478, 0.8634


epoch 48 loss 2.412229: 100%|█████████████████| 211/211 [09:39<00:00,  2.75s/it]
epoch 48 test loss 0.783324: 100%|██████████████| 20/20 [00:10<00:00,  1.93it/s]


epoch 48 TRAIN loss 1.8701
test loss 0.5875, 0.6077, 0.9234, 0.7061, 0.7542, 0.8444, 0.8652
best test loss 0.5899, 0.6345, 0.9277, 0.7384, 0.7607, 0.8478, 0.8652


epoch 49 loss 2.290135: 100%|█████████████████| 211/211 [09:39<00:00,  2.74s/it]
epoch 49 test loss 0.764614: 100%|██████████████| 20/20 [00:10<00:00,  1.94it/s]

epoch 49 TRAIN loss 1.8743
test loss 0.5911, 0.5957, 0.9237, 0.7176, 0.7246, 0.8437, 0.8571
best test loss 0.5911, 0.6345, 0.9277, 0.7384, 0.7607, 0.8478, 0.8652



