## Data preparation

In [None]:
#1
import os
import sys
# os.environ['KMP_DUPLICATE_LIB_OK']='True'
sys.path.append("./")

import cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp

from torch.utils.data import DataLoader
#from dataloaders.datasets import Pathology

import ssl
# SECURITY NOTE(review): this replaces the default HTTPS context factory with
# one that skips certificate verification, for the whole process. Every
# stdlib HTTPS connection that does not pass its own context (http.client,
# urllib) will then accept any certificate. It was presumably added to make
# model-weight downloads work behind a broken CA setup. Prefer fixing the
# CA bundle and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

# sdm
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg
from skimage import morphology

# test loader
import itertools

from scipy.ndimage import distance_transform_edt

import math

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

In [None]:
#2
class Path:
    """Central place for the dataset locations used by this notebook."""

    @staticmethod
    def pathology_root_dir():
        """Return the directory that holds the BCSS patch dataset.

        The base directory is empty, so the result is the relative path
        ``BCSS_patch``, resolved against the current working directory.
        """
        base_dir = ""
        dataset_dir = "BCSS_patch"
        return os.path.join(base_dir, dataset_dir)

### Data visualization

In [None]:
#3
# helper function for data visualization
def visualize(**images):
    """Show every keyword-argument image side by side in a single row.

    Each keyword becomes its subplot title: underscores are turned into
    spaces and the result is title-cased (``ground_truth`` -> "Ground Truth").
    Axis ticks are hidden.
    """
    count = len(images)
    plt.figure(figsize=(20, 8))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(images[name])
    plt.show()

# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of boolean class masks.

    Args:
        label: array whose last axis holds the colour of each pixel.
        label_values: sequence of colours; entry ``k`` defines class ``k``.

    Returns:
        Boolean array shaped like ``label`` without its last axis, plus one
        trailing channel per colour. Channel ``k`` is True where every
        component of the pixel equals ``label_values[k]``.
    """
    masks = [np.all(np.equal(label, colour), axis=-1) for colour in label_values]
    return np.stack(masks, axis=-1)
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """Collapse the last (class) axis to the index of its largest entry."""
    return np.argmax(image, axis=-1)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """Replace every class index in ``image`` by its colour.

    ``image`` holds class indices; they are cast to int and used to index
    ``np.array(label_values)``, so the output gains that array's trailing
    axis (e.g. RGB).
    """
    palette = np.array(label_values)
    indices = image.astype(int)
    return palette[indices]

### Dataset

In [None]:
#4
# Useful to shortlist specific classes in datasets with large number of classes.
# Class index i corresponds to class_names[i] and to the RGB colour
# class_rgb_values[i]. The names previously carried a stray trailing space
# ('tumor '), which only worked because both lists had the same typo.
class_names = ['background', 'tumor', 'stroma', 'dcis']
select_classes = ['background', 'tumor', 'stroma', 'dcis']

background = [[0, 0, 0],]

# Get RGB values of required classes
class_rgb_values = [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]]

select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]

# Report the *selected* classes, as the header line says.
print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', select_classes)
print('Class RGB values: ', select_class_rgb_values.tolist())

### Data augmentation

In [None]:
#5
def get_training_augmentation():
    """Build the training augmentation pipeline.

    Every call of the returned transform applies exactly one randomly chosen
    geometric op: a horizontal flip, a vertical flip, or a random 90-degree
    rotation. ``OneOf`` runs with ``p=1`` and every member also has ``p=1``.
    """
    one_geometric_op = album.OneOf(
        [
            album.HorizontalFlip(p=1),
            album.VerticalFlip(p=1),
            album.RandomRotate90(p=1),
        ],
        p=1,
    )
    return album.Compose([one_geometric_op])


def to_tensor(x, **kwargs):
    """Transpose a 3-D HWC array to CHW and cast it to float32.

    Any extra keyword arguments are accepted and ignored, so the function can
    be wrapped by ``album.Lambda``.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline.

    Args:
        preprocessing_fn: optional image-only function, skipped when falsy.

    Returns:
        ``album.Compose`` that first applies ``preprocessing_fn`` (if any) to
        the image, then ``to_tensor`` to both image and mask.
    """
    steps = []
    if preprocessing_fn:
        steps += [album.Lambda(image=preprocessing_fn)]
    steps += [album.Lambda(image=to_tensor, mask=to_tensor)]
    return album.Compose(steps)

In [None]:
class LabeledDataset(torch.utils.data.Dataset):
    """Labelled BCSS patches: RGB PNG images paired with ``.npy`` masks.

    Images live in ``<root>/labelled/<split>/png``. The mask path is the image
    path with every ``"png"`` substring replaced by ``"npy"``, so both the
    directory and the extension change. The ``.npy`` file must contain a
    pickled dict with a ``'label'`` entry.

    Args:
        args: unused; kept for interface compatibility.
        augmentation: optional albumentations transform called as
            ``augmentation(image=..., mask=...)``; ``None`` disables it.
        split: split sub-directory, e.g. ``"train"`` or ``"test"``.
        img_list: explicit list of image paths. When given, ``fraction`` and
            ``seed`` are ignored.
        fraction: fraction of the split's images sampled without replacement
            when ``img_list`` is None. Values >= 1 keep every image.
        seed: seed for that sampling.
    """

    def __init__(self, args, augmentation, split: str, img_list=None, fraction: float = 1.0, seed: int = 42):
        self.augmentation = augmentation
        self.args = args
        self.data_dir = os.path.join(Path.pathology_root_dir(), 'labelled', split, 'png')

        if img_list is not None:
            self.imgs = img_list
        else:
            all_imgs = sorted(os.path.join(self.data_dir, f) for f in os.listdir(self.data_dir))
            if fraction < 1.0:
                # A private RandomState draws exactly what np.random.seed(seed) +
                # np.random.choice did, without reseeding numpy's global RNG.
                rng = np.random.RandomState(seed)
                selected = rng.choice(len(all_imgs), int(len(all_imgs) * fraction), replace=False)
                self.imgs = [all_imgs[i] for i in sorted(selected)]
            else:
                self.imgs = all_imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``{'img': {'image': CHW float tensor in [0, 1]}, 'label': long tensor, 'path': mask path}``."""
        img_path = self.imgs[index]
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        _img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        mask_path = img_path.replace("png", "npy")
        # allow_pickle is needed for the stored dict; only load trusted files.
        _label = np.load(mask_path, allow_pickle=True).item().get('label')

        if self.augmentation is not None:
            sample = self.augmentation(image=_img, mask=_label)
            _img, _label = sample['image'], sample['mask']

        # Scale to [0, 1] and convert HWC -> CHW. The result is wrapped in
        # {'image': ...} (the shape ToTensorV2 output had) so callers keep
        # reading batch['img']['image'].
        _img = {'image': torch.from_numpy(np.ascontiguousarray((_img / 255).transpose(2, 0, 1)))}
        # ascontiguousarray guards against negative strides left by flips.
        _label = torch.as_tensor(np.ascontiguousarray(_label)).long()

        sample = {
            'img': _img,
            'label': _label,
            'path': mask_path
        }
        return sample

In [None]:
# Randomly keep half of the labelled training images (default seed of
# LabeledDataset). These, augmented, form the labelled pool; their paths are
# reused as labeled_img_list.
train_aug_dataset = LabeledDataset(
    [], get_training_augmentation(), "train", img_list=None, fraction=0.5
)
labeled_img_list = train_aug_dataset.imgs
print(len(train_aug_dataset))

In [None]:
def collate_fn(batch):
    """Collate a batch with PyTorch's default collation after dropping ``None`` samples."""
    kept = list(filter(lambda sample: sample is not None, batch))
    return torch.utils.data.dataloader.default_collate(kept)

In [None]:
# generate unlabeled dataset
class UnlabeledDataset(torch.utils.data.Dataset):
    """Training images that are *not* in ``img_list``, served without labels.

    Images come from ``<root>/labelled/<split>/png``. Files whose base name
    appears in ``img_list`` are excluded, and the rest are kept in sorted order.

    Args:
        args: unused; kept for interface compatibility.
        augmentation: stored but not applied by ``__getitem__``.
        split: split sub-directory, e.g. ``"train"``.
        img_list: paths whose base names should be excluded.
    """

    def __init__(self, args, augmentation, split: str, img_list):
        self.augmentation = augmentation
        self.args = args
        self.data_dir = os.path.join(Path.pathology_root_dir(), 'labelled', split, 'png')

        exclude_files = {os.path.basename(path) for path in img_list}
        self.imgs = [
            os.path.join(self.data_dir, f)
            for f in sorted(os.listdir(self.data_dir))
            if f not in exclude_files
        ]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``{'img': {'image': CHW float tensor in [0, 1]}, 'path': image path}``."""
        img_path = self.imgs[index]
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        _img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Scale to [0, 1] and convert HWC -> CHW. The result is wrapped in
        # {'image': ...} so callers keep reading batch['img']['image'].
        _img = {'image': torch.from_numpy(np.ascontiguousarray((_img / 255).transpose(2, 0, 1)))}

        sample = {
            'img': _img,
            'path': img_path
        }
        return sample

In [None]:
# generate train_loader

def labeled_make_loaders(args, num_workers, pin_memory=True):
    """Build the loaders for the labelled-only (pre-training) stage.

    The labelled pool ``labeled_img_list`` is split 80/20 *by image* before
    augmentation. Training sees each training image twice (plain +
    augmented); validation sees plain images only. Previously the plain and
    augmented copies were concatenated and then split, which let a
    validation image's augmented twin leak into training. The split sizes
    were also hard-coded.

    Args:
        args: unused; kept for interface compatibility.
        num_workers: worker processes for every DataLoader.
        pin_memory: forwarded to every DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader, unlabel_loader)``.
    """
    plain_set = LabeledDataset([], None, "train", labeled_img_list)
    aug_set = LabeledDataset([], get_training_augmentation(), "train", labeled_img_list)
    test_set = LabeledDataset([], None, "test")
    unlabel_set = UnlabeledDataset([], None, "train", labeled_img_list)

    # Both datasets were built from the same img_list, so index i is the same
    # image in each; splitting indices keeps twins on the same side.
    n_images = len(plain_set)
    n_train = int(n_images * 0.8)
    order = torch.randperm(n_images).tolist()
    train_idx, val_idx = order[:n_train], order[n_train:]

    Subset = torch.utils.data.Subset
    train_set = Subset(plain_set, train_idx) + Subset(aug_set, train_idx)
    val_set = Subset(plain_set, val_idx)

    print(f"train set: {len(train_set)} | val set: {len(val_set)} | test set: {len(test_set)} | unlabel set: {len(unlabel_set)}")

    train_loader = DataLoader(
        train_set, batch_size=16, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # Order does not matter for validation metrics.
    val_loader = DataLoader(
        val_set, batch_size=16, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_set, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
    )

    unlabel_loader = DataLoader(
        unlabel_set, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
    )

    return train_loader, val_loader, test_loader, unlabel_loader

label_train_loader, label_val_loader, test_loader, unlabel_loader = labeled_make_loaders([], num_workers=8, pin_memory=True)

### Parameter settings

In [None]:
#9
#label smoothing
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The target puts ``1 - smoothing`` on the true class and
    ``smoothing / (classes - 1)`` on every other class. The loss is the mean,
    over all remaining positions, of ``-sum(target * log_softmax(pred))``.

    Args:
        classes: number of classes; should equal the size of ``pred``'s class
            axis, otherwise the target distribution does not sum to 1.
        smoothing: probability mass moved off the true class, in ``[0, 1)``.
        dim: axis used by ``log_softmax`` and the class sum. The one-hot
            scatter always uses axis 1, so ``dim`` must refer to axis 1. The
            default ``-1`` is only right for 2-D ``(N, C)`` input; pass
            ``dim=1`` for ``(N, C, H, W)`` segmentation output.

    Raises:
        ValueError: if ``smoothing`` is outside ``[0, 1)``, or if
            ``classes < 2`` while ``smoothing > 0``.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        # Validate once, up front. This used to be an assert in forward(),
        # which `python -O` strips.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        if smoothing > 0 and classes < 2:
            raise ValueError("label smoothing needs at least 2 classes")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # With a single class there are no "other" classes to share the
            # smoothing mass (and the old formula divided by zero).
            off_value = self.smoothing / (self.cls - 1) if self.cls > 1 else 0.0
            true_dist = torch.full_like(pred, off_value)
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

In [None]:
#10
def accuracy(dataloader, net=None, device=None):
    """Return the pixel accuracy (percent) of a segmentation model on ``dataloader``.

    Each batch must be a dict with ``batch['img']['image']`` (model input) and
    ``batch['label']`` (integer class map). The prediction is the argmax over
    axis 1 of the model output, compared element-wise with the label.
    Previously ``total`` counted batch entries while ``correct`` counted
    pixels, which inflated the result far above 100%.

    Args:
        dataloader: iterable of such batches.
        net: model to evaluate; defaults to the notebook-global ``model``.
        device: device for inputs/labels; defaults to the global ``DEVICE``.

    Returns:
        ``100 * correct / total`` over all label elements, or 0.0 when the
        loader yields no labelled elements. The model's train/eval mode is
        restored afterwards.
    """
    net = model if net is None else net
    device = DEVICE if device is None else device
    was_training = net.training
    correct = 0
    total = 0
    net.eval()
    try:
        with torch.no_grad():
            for data in dataloader:
                images = data['img']['image'].float().to(device)
                labels = data['label'].to(device, dtype=torch.int64)
                predicted = net(images).argmax(dim=1)
                # Count every labelled pixel, so the ratio is bounded by 100%.
                total += labels.numel()
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    return 100 * correct / total if total else 0.0

### Model configuration

Models tried:
- Semi-supervised
    - Unet
        - encoder: resnet50
        - encoder weights: none
    - Unet
        - encoder: timm-resnest50d
    - DeepLabV3+
        - encoder: resnet50
    - SegFormer
        - encoder: mit_b2
    - Unet with SDM
        - encoder: resnet50
        - encoder weights: imagenet
        - activation: identity
    - Unet with SDM
        - encoder: resnest50
        - encoder weights: imagenet
        - activation: identity

In [None]:
#11
import segmentation_models_pytorch as smp
ENCODER = 'timm-resnest50d' #'timm-resnest50d' 'densenet201' 'resnet34'
ENCODER_WEIGHTS = None
CLASSES = class_names
# The model must emit raw, unnormalised scores ('identity'), not
# probabilities ('softmax2d'):
#   - the SDM branch applies tanh to the output and regresses it onto
#     distance maps in [-1, 1];
#   - LabelSmoothingLoss applies log_softmax itself.
# This matches the "Unet with SDM ... activation: identity" notes above.
ACTIVATION = 'identity'  # 'softmax2d'
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(DEVICE)

model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

model = model.to(DEVICE)

In [None]:
#12
# Set flag to train the model or not. If set to 'False', only prediction is performed (using an older model checkpoint)
TRAINING = True

iterations = 80

print("Device : ",DEVICE)

# One smoothed target per model output class (previously hard-coded to 3
# for a 4-class model). dim=1 is the class axis of the (N, C, H, W) output:
# accuracy() and the loops take the argmax over axis 1.
criterion = LabelSmoothingLoss(len(CLASSES), 0.2, dim=1)

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001, betas=(0.92, 0.99)),
])

# NOTE(review): eta_min equals the base lr, so this cosine schedule keeps the
# lr constant. lr_scheduler.step() is also never called in the visible
# notebook. Lower eta_min and step it per epoch if annealing is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=7, eta_min=0.0001,
)


In [None]:
#13
import time

# Loss histories; nothing in the visible notebook appends to them yet.
trainLoss, valLoss = [], []
start = time.time()

# Semi-supervised weighting. alpha (pseudo-label loss weight) starts at 0 and
# is raised towards alpha_t between epochs T1 and T2 by the training loop.
# beta weights the signed-distance-map (SDM) loss.
alpha, beta, alpha_t = 0, 0.3, 1e-4
T1, T2 = 5, 70
model.to(DEVICE)

In [None]:
# sdm map calculation
def compute_sdf(img_gt, out_shape):
    """Build normalised signed distance maps (SDMs) for a batch of label maps.

    Foreground is every pixel labelled 1 or 2. For each batch item that has
    foreground, the map is

        minmax(outside distance) - minmax(inside distance)

    The inside and outside distances are Euclidean distance transforms of the
    foreground and background masks. Values lie in [-1, 0] on the foreground
    and [0, 1] on the background, and pixels on the inner foreground boundary
    are set to 0. Items with no foreground, or no background, get an all-zero
    map.

    Args:
        img_gt: integer label maps, indexed per batch item as ``img_gt[b]``.
        out_shape: shape of the result; ``out_shape[0]`` is the batch size.
            Each per-item map is assigned to ``result[b]``, so numpy
            broadcasts it over the leading axes of ``out_shape[1:]`` (e.g. a
            channel axis).

    Returns:
        float64 array of shape ``out_shape``.
    """
    labels = img_gt.astype(np.uint8)
    sdf_maps = np.zeros(out_shape)

    for item in range(out_shape[0]):
        # Classes 1 and 2 together form the foreground.
        fg = np.isin(labels[item], [1, 2]).astype(bool)
        if not fg.any():
            continue

        bg = ~fg
        dist_in = distance(fg)
        dist_out = distance(bg)
        edge = skimage_seg.find_boundaries(fg, mode='inner').astype(np.uint8)

        if np.max(dist_in) == 0 or np.max(dist_out) == 0:
            sdf = np.zeros_like(dist_in)
        else:
            out_norm = (dist_out - np.min(dist_out)) / (np.max(dist_out) - np.min(dist_out))
            in_norm = (dist_in - np.min(dist_in)) / (np.max(dist_in) - np.min(dist_in))
            sdf = out_norm - in_norm
            sdf[edge == 1] = 0

        sdf_maps[item] = sdf

    return sdf_maps
    



## Training

### Pre-training

In [None]:
#14: semi-supervised learning: pre-training on the labelled pool only
import tqdm.notebook  # `import tqdm` alone does not guarantee tqdm.notebook is loaded

# NOTE(review): pre_iterations was never defined in this notebook. Fall back
# to the main training length unless it was set elsewhere; confirm the value.
if 'pre_iterations' not in globals():
    pre_iterations = iterations

PRETRAINED_PATH = './pretrained/sdm_unet_resnest50.pth'
os.makedirs(os.path.dirname(PRETRAINED_PATH), exist_ok=True)

torch.backends.cudnn.benchmark = False
print(f"pretraining starts!!")
best_acc = 0  # best validation accuracy over *all* epochs (was reset every epoch)
for epoch in range(pre_iterations):
    print(f"{epoch} iteration start!")
    correct = 0
    total = 0
    model.train() # For training
    print(f'train loader: {len(label_train_loader)}')
    for traindata in tqdm.notebook.tqdm(label_train_loader):
        train_inputs = traindata['img']['image'].float().to(DEVICE)
        train_labels = traindata['label'].to(device=DEVICE, dtype=torch.int64)

        optimizer.zero_grad()

        pred = model(train_inputs)

        # SDM regression: tanh squashes the outputs into [-1, 1], the range of
        # the compute_sdf targets. (compute_sdm did not exist, and the targets
        # must be built from a CPU numpy array.)
        pred_sdm = torch.tanh(pred)
        label_sdm = compute_sdf(train_labels.cpu().numpy(), pred_sdm.shape)
        label_sdm = torch.from_numpy(label_sdm).float().to(pred_sdm.device)
        sdm_loss = F.mse_loss(pred_sdm, label_sdm)

        # LabelSmoothingLoss applies log_softmax itself, so it gets the raw outputs.
        loss = criterion(pred, train_labels) + (beta * sdm_loss)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        predicted = pred.detach().argmax(dim=1)
        total += train_labels.numel()  # pixel accuracy: count pixels, not images
        correct += (predicted == train_labels).sum().item()

    val_acc = accuracy(label_val_loader)
    train_acc = 100 * correct / total if total else 0.0

    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), PRETRAINED_PATH)
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
        print("*******")
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))

### Save pseudo-labels

In [None]:
# save pseudo-labels
import tqdm.notebook  # `import tqdm` alone does not guarantee tqdm.notebook is loaded

PSEUDO_DIR = os.path.join(Path.pathology_root_dir(), 'pseudo', 'sdm_unet_resnest50', 'npy')
os.makedirs(PSEUDO_DIR, exist_ok=True)  # once, not per batch

# Label with the best pre-training checkpoint rather than whatever weights
# the last epoch left in `model`.
best_ckpt = './pretrained/sdm_unet_resnest50.pth'
if os.path.exists(best_ckpt):
    model.load_state_dict(torch.load(best_ckpt, map_location=DEVICE))

model.eval()

with torch.no_grad():
    for batch in tqdm.notebook.tqdm(unlabel_loader):
        unlabel_input = batch['img']['image'].float().to(DEVICE)
        # Argmax over the class axis; a preceding softmax would not change it.
        pred_masks = model(unlabel_input).argmax(dim=1).cpu().numpy()
        inputs_np = unlabel_input.cpu().numpy()

        # One file per sample, so batch sizes > 1 no longer drop samples. The
        # file name must match what PseudoDataset derives ("png" -> "npy").
        for i, img_path in enumerate(batch['path']):
            file_name = os.path.basename(img_path).replace('png', 'npy')
            save_dict = {
                'input': inputs_np[i:i + 1],  # keeps the (1, C, H, W) layout of batch-size-1 runs
                'label': pred_masks[i],
            }
            np.save(os.path.join(PSEUDO_DIR, file_name), save_dict)

### Load the pseudo-label dataset

In [None]:
class PseudoDataset(torch.utils.data.Dataset):
    """Training images outside ``img_list`` paired with stored pseudo-labels.

    Images come from ``<root>/labelled/<split>/png``. Files whose base name
    appears in ``img_list`` are excluded, and the rest are kept in sorted
    order. Each label is read from ``<root>/pseudo/sdm_unet_resnest50/npy/<name>``,
    where ``<name>`` is the image file name with every ``"png"`` replaced by
    ``"npy"``. The file must hold a pickled dict with a ``'label'`` entry.

    Args:
        args: unused; kept for interface compatibility.
        augmentation: optional albumentations transform called as
            ``augmentation(image=..., mask=...)``; ``None`` disables it.
        split: split sub-directory, e.g. ``"train"``.
        img_list: paths whose base names should be excluded.
    """

    def __init__(self, args, augmentation, split: str, img_list):
        self.augmentation = augmentation
        self.args = args
        self.img_dir = os.path.join(Path.pathology_root_dir(), 'labelled', split, 'png')

        exclude_files = {os.path.basename(p) for p in img_list}
        self.imgs = [
            os.path.join(self.img_dir, f)
            for f in sorted(os.listdir(self.img_dir))
            if f not in exclude_files
        ]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``{'img': {'image': CHW float tensor in [0, 1]}, 'label': long tensor, 'path': label path}``."""
        img_path = self.imgs[index]
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        _img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        file_name = os.path.basename(img_path).replace("png", "npy")
        label_path = os.path.join(Path.pathology_root_dir(), 'pseudo', 'sdm_unet_resnest50', 'npy', file_name)
        # allow_pickle is needed for the stored dict; only load trusted files.
        _label = np.load(label_path, allow_pickle=True).item().get('label')

        if self.augmentation is not None:
            sample = self.augmentation(image=_img, mask=_label)
            _img, _label = sample['image'], sample['mask']

        # Scale to [0, 1] and convert HWC -> CHW. The result is wrapped in
        # {'image': ...} so callers keep reading batch['img']['image'].
        _img = {'image': torch.from_numpy(np.ascontiguousarray((_img / 255).transpose(2, 0, 1)))}
        # ascontiguousarray guards against negative strides left by flips.
        _label = torch.as_tensor(np.ascontiguousarray(_label)).long()

        sample = {
            'img': _img,
            'label': _label,
            'path': label_path
        }
        return sample

In [None]:
# Augmented view of every training image that is not in labeled_img_list,
# paired with its stored pseudo-label.
pseudo_aug_dataset = PseudoDataset(
    args=[],
    augmentation=get_training_augmentation(),
    split="train",
    img_list=labeled_img_list,
)
print(len(pseudo_aug_dataset))

In [None]:
# Augmented view of the labelled pool. img_list is given, so LabeledDataset
# ignores the fraction argument and uses exactly these paths.
train_aug_dataset = LabeledDataset(
    args=[],
    augmentation=get_training_augmentation(),
    split="train",
    img_list=labeled_img_list,
    fraction=0.5,
)
print(len(train_aug_dataset))

In [None]:
# labeled+pseudo data loader
def whole_make_loaders(args, num_workers, pin_memory=True):
    """Build loaders for joint training on labelled and pseudo-labelled data.

    Both the labelled pool and the pseudo-labelled pool are split 80/20 *by
    image* before augmentation. Training uses plain + augmented copies of
    the training images; validation uses plain images from both pools.
    Previously the plain+augmented concatenation was split, letting an
    image's augmented twin leak across the split.

    Args:
        args: unused; kept for interface compatibility.
        num_workers: worker processes for every DataLoader.
        pin_memory: forwarded to every DataLoader.

    Returns:
        ``(train_loader, pseudo_train_loader, val_loader, test_loader)``.
    """
    def _split_by_image(plain, aug, train_fraction=0.8):
        """Split two index-aligned datasets; return (plain+aug train subset, plain val subset)."""
        n_images = len(plain)
        n_train = int(n_images * train_fraction)
        order = torch.randperm(n_images).tolist()
        train_idx, val_idx = order[:n_train], order[n_train:]
        Subset = torch.utils.data.Subset
        return Subset(plain, train_idx) + Subset(aug, train_idx), Subset(plain, val_idx)

    # Each plain/augmented pair is built with identical arguments, so index
    # i refers to the same image in both datasets.
    labeled_train_set, labeled_val_set = _split_by_image(
        LabeledDataset([], None, "train", labeled_img_list),
        LabeledDataset([], get_training_augmentation(), "train", labeled_img_list),
    )
    pseudo_train_set, pseudo_val_set = _split_by_image(
        PseudoDataset([], None, "train", labeled_img_list),
        PseudoDataset([], get_training_augmentation(), "train", labeled_img_list),
    )

    test_set = LabeledDataset([], None, "test")

    val_set = pseudo_val_set + labeled_val_set

    print(f'labeled train set: {len(labeled_train_set)} | pseudo train set: {len(pseudo_train_set)} | val set: {len(val_set)} | test_set: {len(test_set)}')

    train_loader = DataLoader(
        labeled_train_set, batch_size=16, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    pseudo_train_loader = DataLoader(
        pseudo_train_set, batch_size=16, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_set, batch_size=16, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_set, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return train_loader, pseudo_train_loader, val_loader, test_loader

train_loader, pseudo_train_loader, val_loader, test_loader = whole_make_loaders([], num_workers=8, pin_memory=True)

In [None]:
# label+pseudo train : without sdm
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = False

SEMI_CKPT = './[BCSS]segmentation_model/sdm_semi_unet_resnest50.pth'
os.makedirs(os.path.dirname(SEMI_CKPT), exist_ok=True)

best_acc = 0  # best validation accuracy over all epochs (was reset every epoch)
for epoch in range(iterations):
    print(epoch,"epoch start !")

    # Pseudo-label weight: 0 up to T1, linear ramp towards alpha_t until T2,
    # then constant. Without this the pseudo branch below never ran.
    if T1 < epoch < T2:
        alpha = alpha_t * (epoch - T1) / (T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t
    else:
        alpha = 0

    correct = 0
    total = 0
    model.train() # For training
    print("train_loader", len(train_loader))
    for traindata, pseudodata in zip(train_loader, pseudo_train_loader):
        train_inputs = traindata['img']['image'].float().to(DEVICE)
        train_labels = traindata['label'].to(device=DEVICE, dtype=torch.int64)

        optimizer.zero_grad()
        # Raw outputs: LabelSmoothingLoss applies log_softmax itself.
        train_outputs = model(train_inputs)
        loss = criterion(train_outputs, train_labels)

        if alpha > 0:
            pseudo_inputs = pseudodata['img']['image'].float().to(DEVICE)
            pseudo_outputs = model(pseudo_inputs)
            # NOTE(review): as in the original code, the stored pseudo-labels
            # (pseudodata['label']) are replaced by the model's current
            # predictions. Confirm whether the offline labels were meant to be used.
            pseudo_labels = pseudo_outputs.detach().argmax(dim=1)
            loss = loss + alpha * criterion(pseudo_outputs, pseudo_labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        predicted = train_outputs.detach().argmax(dim=1)
        total += train_labels.numel()  # pixel accuracy: count pixels, not images
        correct += (predicted == train_labels).sum().item()

    val_acc = accuracy(val_loader)
    train_acc = 100 * correct / total if total else 0.0
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), SEMI_CKPT)
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
        print("*******")
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))
        print("*******")

In [None]:
# pseudo-label training: with SDM
torch.backends.cudnn.benchmark = False

SDM_SEMI_CKPT = './model/resnest_pseudo5.pth'
os.makedirs(os.path.dirname(SDM_SEMI_CKPT), exist_ok=True)

# wandb is never imported in this notebook; log to it only if a session has
# been set up elsewhere.
wandb_module = globals().get('wandb')

best_acc = 0  # best validation accuracy over all epochs (was reset every epoch)
for epoch in range(iterations):
    print(epoch,"epoch start !")

    # Pseudo-label weight: 0 up to T1, linear ramp towards alpha_t until T2,
    # then constant. It depends only on the epoch, so it is set once per epoch
    # (it used to be recomputed inside the batch loop).
    if T1 < epoch < T2:
        alpha = alpha_t * (epoch - T1) / (T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t
    else:
        alpha = 0

    correct = 0
    total = 0
    model.train() # For training
    print("train_loader", len(train_loader))
    print("pseudo_train_loader", len(pseudo_train_loader))
    # unlabel_loader batches carry no 'label' key and have batch size 1; the
    # pseudo-labelled loader is the intended partner here.
    for traindata, pseudodata in zip(train_loader, pseudo_train_loader):
        train_inputs = traindata['img']['image'].float().to(DEVICE)
        train_labels = traindata['label'].to(device=DEVICE, dtype=torch.int64)

        optimizer.zero_grad()

        pred = model(train_inputs)

        # SDM regression on tanh-squashed outputs (compute_sdf targets lie in [-1, 1]).
        pred_sdm = torch.tanh(pred)
        with torch.no_grad():
            label_sdm = compute_sdf(train_labels.cpu().numpy(), pred_sdm.shape)
            label_sdm = torch.from_numpy(label_sdm).float().to(device=DEVICE)
        # torch.nn.functional has no l2_loss; mse_loss is the L2 regression loss.
        sdm_loss = F.mse_loss(pred_sdm, label_sdm)

        # LabelSmoothingLoss applies log_softmax itself, so it gets the raw outputs.
        loss = criterion(pred, train_labels) + beta * sdm_loss

        if alpha > 0:
            pseudo_inputs = pseudodata['img']['image'].float().to(DEVICE)
            pseudo_pred = model(pseudo_inputs)
            # NOTE(review): as in the original code, the stored pseudo-labels
            # are replaced by the model's current predictions. Confirm intent.
            pseudo_labels = pseudo_pred.detach().argmax(dim=1)

            pseudo_sdm = torch.tanh(pseudo_pred)
            with torch.no_grad():
                pseudo_label_sdm = compute_sdf(pseudo_labels.cpu().numpy(), pseudo_sdm.shape)
                pseudo_label_sdm = torch.from_numpy(pseudo_label_sdm).float().to(device=DEVICE)
            pseudo_sdm_loss = F.mse_loss(pseudo_sdm, pseudo_label_sdm)
            loss = loss + alpha * criterion(pseudo_pred, pseudo_labels) + beta * pseudo_sdm_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        predicted = pred.detach().argmax(dim=1)
        total += train_labels.numel()  # pixel accuracy: count pixels, not images
        correct += (predicted == train_labels).sum().item()

    val_acc = accuracy(val_loader)
    train_acc = 100 * correct / total if total else 0.0
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), SDM_SEMI_CKPT)
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
        if wandb_module is not None:
            wandb_module.log({"train acc": train_acc, "validation acc": val_acc})
        print("*******")
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))
        if wandb_module is not None:
            wandb_module.log({"train acc": train_acc, "validation acc": val_acc})
        print("*******")
